feat: implement ntfy-discord bridge in Go
- SSE subscription to ntfy with auto-reconnect - Discord webhook integration with embed formatting - Priority to color mapping, tag to emoji conversion - Native HashiCorp Vault support (Kubernetes + token auth) - Hot reload secrets via fsnotify or Vault polling - Prometheus metrics (/metrics endpoint) - Health/ready endpoints for Kubernetes probes - Comprehensive unit tests and fuzz tests - Multi-stage Docker build (~10MB scratch image) - CI/CD pipeline for Gitea Actions
This commit is contained in:
116
internal/bridge/bridge.go
Normal file
116
internal/bridge/bridge.go
Normal file
@@ -0,0 +1,116 @@
|
||||
package bridge
|
||||
|
||||
import (
|
||||
"context"
|
||||
"log/slog"
|
||||
"sync/atomic"
|
||||
|
||||
"git.daviestechlabs.io/daviestechlabs/ntfy-discord/internal/config"
|
||||
"git.daviestechlabs.io/daviestechlabs/ntfy-discord/internal/discord"
|
||||
"git.daviestechlabs.io/daviestechlabs/ntfy-discord/internal/ntfy"
|
||||
"github.com/prometheus/client_golang/prometheus"
|
||||
"github.com/prometheus/client_golang/prometheus/promauto"
|
||||
)
|
||||
|
||||
var (
|
||||
messagesReceived = promauto.NewCounterVec(prometheus.CounterOpts{
|
||||
Name: "ntfy_discord_messages_received_total",
|
||||
Help: "Total number of messages received from ntfy",
|
||||
}, []string{"topic"})
|
||||
|
||||
messagesSent = promauto.NewCounterVec(prometheus.CounterOpts{
|
||||
Name: "ntfy_discord_messages_sent_total",
|
||||
Help: "Total number of messages sent to Discord",
|
||||
}, []string{"topic"})
|
||||
|
||||
messagesErrors = promauto.NewCounterVec(prometheus.CounterOpts{
|
||||
Name: "ntfy_discord_messages_errors_total",
|
||||
Help: "Total number of errors sending to Discord",
|
||||
}, []string{"topic"})
|
||||
)
|
||||
|
||||
// Bridge connects ntfy to Discord
|
||||
type Bridge struct {
|
||||
cfg *config.Config
|
||||
ntfyClient *ntfy.Client
|
||||
discordClient *discord.Client
|
||||
ready atomic.Bool
|
||||
}
|
||||
|
||||
// New creates a new Bridge
|
||||
func New(cfg *config.Config) *Bridge {
|
||||
return &Bridge{
|
||||
cfg: cfg,
|
||||
ntfyClient: ntfy.NewClient(cfg.NtfyURL, cfg.NtfyTopics),
|
||||
discordClient: discord.NewClient(),
|
||||
}
|
||||
}
|
||||
|
||||
// Run starts the bridge
|
||||
func (b *Bridge) Run(ctx context.Context) {
|
||||
if len(b.cfg.NtfyTopics) == 0 {
|
||||
slog.Error("no topics configured")
|
||||
return
|
||||
}
|
||||
|
||||
msgCh := make(chan ntfy.Message, 100)
|
||||
|
||||
// Start SSE subscription
|
||||
go b.ntfyClient.Subscribe(ctx, msgCh)
|
||||
|
||||
// Mark as ready once we start processing
|
||||
b.ready.Store(true)
|
||||
|
||||
// Process messages
|
||||
for {
|
||||
select {
|
||||
case msg := <-msgCh:
|
||||
b.handleMessage(ctx, msg)
|
||||
case <-ctx.Done():
|
||||
b.ready.Store(false)
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// IsReady returns true if the bridge is ready to process messages
|
||||
func (b *Bridge) IsReady() bool {
|
||||
return b.ready.Load()
|
||||
}
|
||||
|
||||
// IsHealthy returns true if the bridge is healthy
|
||||
func (b *Bridge) IsHealthy() bool {
|
||||
// For now, healthy == ready
|
||||
// Could add more checks like webhook URL configured
|
||||
return b.cfg.WebhookURL() != ""
|
||||
}
|
||||
|
||||
func (b *Bridge) handleMessage(ctx context.Context, msg ntfy.Message) {
|
||||
messagesReceived.WithLabelValues(msg.Topic).Inc()
|
||||
|
||||
webhookURL := b.cfg.WebhookURL()
|
||||
if webhookURL == "" {
|
||||
slog.Warn("no webhook URL configured, dropping message", "topic", msg.Topic)
|
||||
messagesErrors.WithLabelValues(msg.Topic).Inc()
|
||||
return
|
||||
}
|
||||
|
||||
slog.Info("forwarding message to Discord",
|
||||
"id", msg.ID,
|
||||
"topic", msg.Topic,
|
||||
"title", msg.Title,
|
||||
)
|
||||
|
||||
if err := b.discordClient.Send(ctx, webhookURL, msg); err != nil {
|
||||
slog.Error("failed to send to Discord",
|
||||
"error", err,
|
||||
"topic", msg.Topic,
|
||||
"id", msg.ID,
|
||||
)
|
||||
messagesErrors.WithLabelValues(msg.Topic).Inc()
|
||||
return
|
||||
}
|
||||
|
||||
messagesSent.WithLabelValues(msg.Topic).Inc()
|
||||
slog.Debug("message sent to Discord", "id", msg.ID)
|
||||
}
|
||||
115
internal/bridge/bridge_test.go
Normal file
115
internal/bridge/bridge_test.go
Normal file
@@ -0,0 +1,115 @@
|
||||
package bridge
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"git.daviestechlabs.io/daviestechlabs/ntfy-discord/internal/ntfy"
|
||||
)
|
||||
|
||||
func TestBridge_IsReady(t *testing.T) {
|
||||
// Create a minimal bridge for testing ready state
|
||||
b := &Bridge{}
|
||||
|
||||
// Initially not ready
|
||||
if b.IsReady() {
|
||||
t.Error("bridge should not be ready before Run()")
|
||||
}
|
||||
|
||||
// Set ready
|
||||
b.ready.Store(true)
|
||||
if !b.IsReady() {
|
||||
t.Error("bridge should be ready after ready.Store(true)")
|
||||
}
|
||||
|
||||
// Unset ready
|
||||
b.ready.Store(false)
|
||||
if b.IsReady() {
|
||||
t.Error("bridge should not be ready after ready.Store(false)")
|
||||
}
|
||||
}
|
||||
|
||||
func TestMetricsRegistered(t *testing.T) {
|
||||
// Verify metrics are registered by checking they're not nil
|
||||
if messagesReceived == nil {
|
||||
t.Error("messagesReceived metric not registered")
|
||||
}
|
||||
if messagesSent == nil {
|
||||
t.Error("messagesSent metric not registered")
|
||||
}
|
||||
if messagesErrors == nil {
|
||||
t.Error("messagesErrors metric not registered")
|
||||
}
|
||||
}
|
||||
|
||||
// Test that Bridge correctly uses the ready atomic
|
||||
func TestBridge_ReadyState_Concurrent(t *testing.T) {
|
||||
b := &Bridge{}
|
||||
|
||||
// Test concurrent access
|
||||
var wg sync.WaitGroup
|
||||
for i := 0; i < 100; i++ {
|
||||
wg.Add(2)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
b.ready.Store(true)
|
||||
}()
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
_ = b.IsReady()
|
||||
}()
|
||||
}
|
||||
wg.Wait()
|
||||
|
||||
// Should not have any race conditions
|
||||
}
|
||||
|
||||
// Test message channel buffer
|
||||
func TestMessageChannelBuffer(t *testing.T) {
|
||||
msgCh := make(chan ntfy.Message, 100)
|
||||
|
||||
// Should be able to buffer 100 messages without blocking
|
||||
for i := 0; i < 100; i++ {
|
||||
select {
|
||||
case msgCh <- ntfy.Message{ID: "test"}:
|
||||
// OK
|
||||
default:
|
||||
t.Fatalf("channel blocked at message %d", i)
|
||||
}
|
||||
}
|
||||
|
||||
// 101st should block (use select to avoid blocking test)
|
||||
select {
|
||||
case msgCh <- ntfy.Message{ID: "overflow"}:
|
||||
t.Error("channel should be full")
|
||||
default:
|
||||
// Expected - channel full
|
||||
}
|
||||
}
|
||||
|
||||
// Test context cancellation behavior
|
||||
func TestContextCancellation(t *testing.T) {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
// Simulate Run() loop behavior
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
close(done)
|
||||
case <-time.After(5 * time.Second):
|
||||
// Should not reach here
|
||||
}
|
||||
}()
|
||||
|
||||
cancel()
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
// Success
|
||||
case <-time.After(time.Second):
|
||||
t.Error("context cancellation not handled")
|
||||
}
|
||||
}
|
||||
233
internal/config/config.go
Normal file
233
internal/config/config.go
Normal file
@@ -0,0 +1,233 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"context"
|
||||
"log/slog"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/fsnotify/fsnotify"
|
||||
|
||||
"git.daviestechlabs.io/daviestechlabs/ntfy-discord/internal/vault"
|
||||
)
|
||||
|
||||
// Config holds the application configuration
|
||||
type Config struct {
|
||||
NtfyURL string
|
||||
NtfyTopics []string
|
||||
SecretsPath string
|
||||
HTTPPort string
|
||||
|
||||
// Vault configuration
|
||||
VaultEnabled bool
|
||||
vaultClient *vault.Client
|
||||
|
||||
mu sync.RWMutex
|
||||
webhookURL string
|
||||
}
|
||||
|
||||
// Load creates a new Config from environment variables
|
||||
func Load(ctx context.Context) (*Config, error) {
|
||||
cfg := &Config{
|
||||
NtfyURL: getEnv("NTFY_URL", "http://ntfy.observability.svc.cluster.local"),
|
||||
SecretsPath: getEnv("SECRETS_PATH", ""),
|
||||
HTTPPort: getEnv("HTTP_PORT", "8080"),
|
||||
VaultEnabled: getEnv("VAULT_ENABLED", "false") == "true",
|
||||
}
|
||||
|
||||
// Parse topics
|
||||
topics := getEnv("NTFY_TOPICS", "")
|
||||
if topics != "" {
|
||||
cfg.NtfyTopics = strings.Split(topics, ",")
|
||||
for i := range cfg.NtfyTopics {
|
||||
cfg.NtfyTopics[i] = strings.TrimSpace(cfg.NtfyTopics[i])
|
||||
}
|
||||
}
|
||||
|
||||
// Try Vault first if enabled
|
||||
if cfg.VaultEnabled {
|
||||
if err := cfg.initVault(ctx); err != nil {
|
||||
slog.Warn("vault init failed, falling back to file/env", "error", err)
|
||||
} else {
|
||||
// Load webhook from Vault
|
||||
if webhookURL, err := cfg.vaultClient.GetSecret(ctx, "webhook-url"); err == nil {
|
||||
cfg.mu.Lock()
|
||||
cfg.webhookURL = webhookURL
|
||||
cfg.mu.Unlock()
|
||||
slog.Info("loaded webhook URL from vault")
|
||||
} else {
|
||||
slog.Warn("failed to get webhook from vault", "error", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Fall back to file-based secret
|
||||
if cfg.webhookURL == "" && cfg.SecretsPath != "" {
|
||||
if err := cfg.loadWebhookFromSecret(); err != nil {
|
||||
slog.Warn("failed to load webhook from secret, trying env", "error", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Fall back to environment variable
|
||||
if cfg.webhookURL == "" {
|
||||
cfg.webhookURL = getEnv("DISCORD_WEBHOOK_URL", "")
|
||||
}
|
||||
|
||||
if cfg.webhookURL == "" {
|
||||
slog.Warn("no Discord webhook URL configured")
|
||||
}
|
||||
|
||||
slog.Info("config loaded",
|
||||
"ntfy_url", cfg.NtfyURL,
|
||||
"topics", cfg.NtfyTopics,
|
||||
"secrets_path", cfg.SecretsPath,
|
||||
"vault_enabled", cfg.VaultEnabled,
|
||||
)
|
||||
|
||||
return cfg, nil
|
||||
}
|
||||
|
||||
// initVault initializes the Vault client
|
||||
func (c *Config) initVault(ctx context.Context) error {
|
||||
vaultCfg := vault.Config{
|
||||
Address: getEnv("VAULT_ADDR", "http://vault.vault.svc.cluster.local:8200"),
|
||||
AuthMethod: getEnv("VAULT_AUTH_METHOD", "kubernetes"),
|
||||
Role: getEnv("VAULT_ROLE", "ntfy-discord"),
|
||||
MountPath: getEnv("VAULT_MOUNT_PATH", "secret"),
|
||||
SecretPath: getEnv("VAULT_SECRET_PATH", "ntfy-discord"),
|
||||
TokenPath: getEnv("VAULT_TOKEN_PATH", ""),
|
||||
}
|
||||
|
||||
client, err := vault.NewClient(vaultCfg)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
c.vaultClient = client
|
||||
return nil
|
||||
}
|
||||
|
||||
// WebhookURL returns the current webhook URL (thread-safe)
|
||||
func (c *Config) WebhookURL() string {
|
||||
c.mu.RLock()
|
||||
defer c.mu.RUnlock()
|
||||
return c.webhookURL
|
||||
}
|
||||
|
||||
// loadWebhookFromSecret reads the webhook URL from the mounted secret
|
||||
func (c *Config) loadWebhookFromSecret() error {
|
||||
path := filepath.Join(c.SecretsPath, "webhook-url")
|
||||
data, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
c.mu.Lock()
|
||||
c.webhookURL = strings.TrimSpace(string(data))
|
||||
c.mu.Unlock()
|
||||
|
||||
slog.Debug("loaded webhook URL from secret")
|
||||
return nil
|
||||
}
|
||||
|
||||
// WatchSecrets watches the secrets directory for changes and reloads
|
||||
func (c *Config) WatchSecrets(ctx context.Context) {
|
||||
// Start Vault watcher if enabled
|
||||
if c.vaultClient != nil {
|
||||
go c.watchVaultSecrets(ctx)
|
||||
}
|
||||
|
||||
// Start file watcher if secrets path is set
|
||||
if c.SecretsPath != "" {
|
||||
go c.watchFileSecrets(ctx)
|
||||
}
|
||||
}
|
||||
|
||||
// watchVaultSecrets periodically refreshes secrets from Vault
|
||||
func (c *Config) watchVaultSecrets(ctx context.Context) {
|
||||
interval := 30 * time.Second
|
||||
ticker := time.NewTicker(interval)
|
||||
defer ticker.Stop()
|
||||
|
||||
slog.Info("watching vault secrets", "interval", interval)
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ticker.C:
|
||||
if webhookURL, err := c.vaultClient.GetSecret(ctx, "webhook-url"); err == nil {
|
||||
c.mu.Lock()
|
||||
if c.webhookURL != webhookURL {
|
||||
c.webhookURL = webhookURL
|
||||
slog.Info("webhook URL updated from vault")
|
||||
}
|
||||
c.mu.Unlock()
|
||||
} else {
|
||||
slog.Error("failed to refresh webhook from vault", "error", err)
|
||||
}
|
||||
case <-ctx.Done():
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// watchFileSecrets watches the secrets directory for changes
|
||||
func (c *Config) watchFileSecrets(ctx context.Context) {
|
||||
watcher, err := fsnotify.NewWatcher()
|
||||
if err != nil {
|
||||
slog.Error("failed to create fsnotify watcher", "error", err)
|
||||
return
|
||||
}
|
||||
defer watcher.Close()
|
||||
|
||||
// Watch the secrets directory
|
||||
// Kubernetes updates secrets by changing the symlink, so watch the parent
|
||||
if err := watcher.Add(c.SecretsPath); err != nil {
|
||||
slog.Error("failed to watch secrets path", "error", err, "path", c.SecretsPath)
|
||||
return
|
||||
}
|
||||
|
||||
slog.Info("watching secrets for changes", "path", c.SecretsPath)
|
||||
|
||||
for {
|
||||
select {
|
||||
case event, ok := <-watcher.Events:
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
// Kubernetes updates secrets via symlink, which triggers Create events
|
||||
if event.Has(fsnotify.Create) || event.Has(fsnotify.Write) {
|
||||
slog.Info("secret changed, reloading", "event", event.Name)
|
||||
if err := c.loadWebhookFromSecret(); err != nil {
|
||||
slog.Error("failed to reload webhook from secret", "error", err)
|
||||
} else {
|
||||
slog.Info("webhook URL reloaded successfully")
|
||||
}
|
||||
}
|
||||
case err, ok := <-watcher.Errors:
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
slog.Error("fsnotify error", "error", err)
|
||||
case <-ctx.Done():
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Close cleans up resources
|
||||
func (c *Config) Close() error {
|
||||
if c.vaultClient != nil {
|
||||
return c.vaultClient.Close()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func getEnv(key, defaultVal string) string {
|
||||
if val := os.Getenv(key); val != "" {
|
||||
return val
|
||||
}
|
||||
return defaultVal
|
||||
}
|
||||
223
internal/config/config_test.go
Normal file
223
internal/config/config_test.go
Normal file
@@ -0,0 +1,223 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"context"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"sync"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestGetEnv(t *testing.T) {
|
||||
// Set a test env var
|
||||
os.Setenv("TEST_CONFIG_VAR", "test_value")
|
||||
defer os.Unsetenv("TEST_CONFIG_VAR")
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
key string
|
||||
defaultVal string
|
||||
want string
|
||||
}{
|
||||
{"existing var", "TEST_CONFIG_VAR", "default", "test_value"},
|
||||
{"non-existing var", "NON_EXISTING_VAR", "default", "default"},
|
||||
{"empty default", "NON_EXISTING_VAR_2", "", ""},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := getEnv(tt.key, tt.defaultVal)
|
||||
if got != tt.want {
|
||||
t.Errorf("getEnv(%s, %s) = %s, want %s", tt.key, tt.defaultVal, got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestConfig_WebhookURL_ThreadSafe(t *testing.T) {
|
||||
cfg := &Config{}
|
||||
cfg.webhookURL = "https://discord.com/api/webhooks/test"
|
||||
|
||||
// Test concurrent reads and writes
|
||||
var wg sync.WaitGroup
|
||||
for i := 0; i < 100; i++ {
|
||||
wg.Add(2)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
_ = cfg.WebhookURL()
|
||||
}()
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
cfg.mu.Lock()
|
||||
cfg.webhookURL = "updated"
|
||||
cfg.mu.Unlock()
|
||||
}()
|
||||
}
|
||||
wg.Wait()
|
||||
|
||||
// Should not race
|
||||
}
|
||||
|
||||
func TestConfig_LoadWebhookFromSecret(t *testing.T) {
|
||||
// Create temp directory with secret file
|
||||
tmpDir := t.TempDir()
|
||||
secretPath := filepath.Join(tmpDir, "webhook-url")
|
||||
webhookURL := "https://discord.com/api/webhooks/123/abc"
|
||||
|
||||
if err := os.WriteFile(secretPath, []byte(webhookURL+"\n"), 0644); err != nil {
|
||||
t.Fatalf("failed to write secret file: %v", err)
|
||||
}
|
||||
|
||||
cfg := &Config{
|
||||
SecretsPath: tmpDir,
|
||||
}
|
||||
|
||||
if err := cfg.loadWebhookFromSecret(); err != nil {
|
||||
t.Fatalf("loadWebhookFromSecret() error = %v", err)
|
||||
}
|
||||
|
||||
if cfg.webhookURL != webhookURL {
|
||||
t.Errorf("webhookURL = %s, want %s", cfg.webhookURL, webhookURL)
|
||||
}
|
||||
}
|
||||
|
||||
func TestConfig_LoadWebhookFromSecret_NotFound(t *testing.T) {
|
||||
cfg := &Config{
|
||||
SecretsPath: "/nonexistent/path",
|
||||
}
|
||||
|
||||
err := cfg.loadWebhookFromSecret()
|
||||
if err == nil {
|
||||
t.Error("expected error for non-existent secret")
|
||||
}
|
||||
}
|
||||
|
||||
func TestConfig_LoadWebhookFromSecret_TrimsWhitespace(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
secretPath := filepath.Join(tmpDir, "webhook-url")
|
||||
|
||||
// Write with extra whitespace
|
||||
if err := os.WriteFile(secretPath, []byte(" https://example.com \n\n"), 0644); err != nil {
|
||||
t.Fatalf("failed to write secret file: %v", err)
|
||||
}
|
||||
|
||||
cfg := &Config{SecretsPath: tmpDir}
|
||||
if err := cfg.loadWebhookFromSecret(); err != nil {
|
||||
t.Fatalf("loadWebhookFromSecret() error = %v", err)
|
||||
}
|
||||
|
||||
if cfg.webhookURL != "https://example.com" {
|
||||
t.Errorf("webhookURL = %q, want %q", cfg.webhookURL, "https://example.com")
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoad_ParsesTopics(t *testing.T) {
|
||||
os.Setenv("NTFY_TOPICS", "alerts, updates , notifications")
|
||||
os.Setenv("VAULT_ENABLED", "false")
|
||||
defer func() {
|
||||
os.Unsetenv("NTFY_TOPICS")
|
||||
os.Unsetenv("VAULT_ENABLED")
|
||||
}()
|
||||
|
||||
cfg, err := Load(context.Background())
|
||||
if err != nil {
|
||||
t.Fatalf("Load() error = %v", err)
|
||||
}
|
||||
|
||||
expected := []string{"alerts", "updates", "notifications"}
|
||||
if len(cfg.NtfyTopics) != len(expected) {
|
||||
t.Errorf("NtfyTopics length = %d, want %d", len(cfg.NtfyTopics), len(expected))
|
||||
}
|
||||
|
||||
for i, topic := range cfg.NtfyTopics {
|
||||
if topic != expected[i] {
|
||||
t.Errorf("NtfyTopics[%d] = %s, want %s", i, topic, expected[i])
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoad_Defaults(t *testing.T) {
|
||||
// Clear any existing env vars
|
||||
os.Unsetenv("NTFY_URL")
|
||||
os.Unsetenv("HTTP_PORT")
|
||||
os.Unsetenv("VAULT_ENABLED")
|
||||
|
||||
cfg, err := Load(context.Background())
|
||||
if err != nil {
|
||||
t.Fatalf("Load() error = %v", err)
|
||||
}
|
||||
|
||||
if cfg.NtfyURL != "http://ntfy.observability.svc.cluster.local" {
|
||||
t.Errorf("NtfyURL = %s, want default", cfg.NtfyURL)
|
||||
}
|
||||
|
||||
if cfg.HTTPPort != "8080" {
|
||||
t.Errorf("HTTPPort = %s, want 8080", cfg.HTTPPort)
|
||||
}
|
||||
|
||||
if cfg.VaultEnabled {
|
||||
t.Error("VaultEnabled should be false by default")
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoad_VaultEnabled(t *testing.T) {
|
||||
os.Setenv("VAULT_ENABLED", "true")
|
||||
defer os.Unsetenv("VAULT_ENABLED")
|
||||
|
||||
// This will fail to init Vault (no server), but should gracefully fall back
|
||||
cfg, err := Load(context.Background())
|
||||
if err != nil {
|
||||
t.Fatalf("Load() error = %v", err)
|
||||
}
|
||||
|
||||
if !cfg.VaultEnabled {
|
||||
t.Error("VaultEnabled should be true")
|
||||
}
|
||||
|
||||
// vaultClient should be nil (failed to connect)
|
||||
if cfg.vaultClient != nil {
|
||||
t.Error("vaultClient should be nil when Vault unavailable")
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoad_FallsBackToEnvVar(t *testing.T) {
|
||||
webhookURL := "https://discord.com/api/webhooks/env/test"
|
||||
os.Setenv("DISCORD_WEBHOOK_URL", webhookURL)
|
||||
os.Setenv("VAULT_ENABLED", "false")
|
||||
defer func() {
|
||||
os.Unsetenv("DISCORD_WEBHOOK_URL")
|
||||
os.Unsetenv("VAULT_ENABLED")
|
||||
}()
|
||||
|
||||
cfg, err := Load(context.Background())
|
||||
if err != nil {
|
||||
t.Fatalf("Load() error = %v", err)
|
||||
}
|
||||
|
||||
if cfg.WebhookURL() != webhookURL {
|
||||
t.Errorf("WebhookURL() = %s, want %s", cfg.WebhookURL(), webhookURL)
|
||||
}
|
||||
}
|
||||
|
||||
func TestConfig_Close_NoVault(t *testing.T) {
|
||||
cfg := &Config{}
|
||||
|
||||
// Should not panic with nil vaultClient
|
||||
err := cfg.Close()
|
||||
if err != nil {
|
||||
t.Errorf("Close() error = %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestConfig_WatchSecrets_NoPath(t *testing.T) {
|
||||
cfg := &Config{
|
||||
SecretsPath: "",
|
||||
vaultClient: nil,
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
// Should return immediately, not block or panic
|
||||
cfg.WatchSecrets(ctx)
|
||||
}
|
||||
72
internal/config/fuzz_test.go
Normal file
72
internal/config/fuzz_test.go
Normal file
@@ -0,0 +1,72 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
// FuzzParseTopics tests topic parsing doesn't panic on arbitrary input
|
||||
func FuzzParseTopics(f *testing.F) {
|
||||
// Normal cases
|
||||
f.Add("alerts")
|
||||
f.Add("alerts,updates")
|
||||
f.Add("alerts, updates, notifications")
|
||||
f.Add("")
|
||||
|
||||
// Edge cases
|
||||
f.Add(",,,")
|
||||
f.Add(" , , ")
|
||||
f.Add("a]topic-with-special_chars.123")
|
||||
f.Add("\x00\x00\x00")
|
||||
f.Add("topic\nwith\nnewlines")
|
||||
|
||||
f.Fuzz(func(t *testing.T, input string) {
|
||||
// Simulate topic parsing logic
|
||||
if input == "" {
|
||||
return
|
||||
}
|
||||
|
||||
topics := make([]string, 0)
|
||||
for _, topic := range splitAndTrim(input) {
|
||||
topics = append(topics, topic)
|
||||
}
|
||||
|
||||
// Accessing results should not panic
|
||||
_ = len(topics)
|
||||
})
|
||||
}
|
||||
|
||||
// splitAndTrim mimics the topic parsing in Load()
|
||||
func splitAndTrim(s string) []string {
|
||||
if s == "" {
|
||||
return nil
|
||||
}
|
||||
var result []string
|
||||
start := 0
|
||||
for i := 0; i < len(s); i++ {
|
||||
if s[i] == ',' {
|
||||
part := trimSpace(s[start:i])
|
||||
if part != "" {
|
||||
result = append(result, part)
|
||||
}
|
||||
start = i + 1
|
||||
}
|
||||
}
|
||||
// Last part
|
||||
part := trimSpace(s[start:])
|
||||
if part != "" {
|
||||
result = append(result, part)
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
func trimSpace(s string) string {
|
||||
start := 0
|
||||
end := len(s)
|
||||
for start < end && (s[start] == ' ' || s[start] == '\t' || s[start] == '\n' || s[start] == '\r') {
|
||||
start++
|
||||
}
|
||||
for end > start && (s[end-1] == ' ' || s[end-1] == '\t' || s[end-1] == '\n' || s[end-1] == '\r') {
|
||||
end--
|
||||
}
|
||||
return s[start:end]
|
||||
}
|
||||
125
internal/discord/fuzz_test.go
Normal file
125
internal/discord/fuzz_test.go
Normal file
@@ -0,0 +1,125 @@
|
||||
package discord
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"testing"
|
||||
|
||||
"git.daviestechlabs.io/daviestechlabs/ntfy-discord/internal/ntfy"
|
||||
)
|
||||
|
||||
// FuzzBuildEmbed tests that embed building doesn't panic on arbitrary input
|
||||
func FuzzBuildEmbed(f *testing.F) {
|
||||
// Seed with normal inputs
|
||||
f.Add("Test Title", "Test message body", 3, "warning", "test-topic", "https://example.com")
|
||||
f.Add("", "", 0, "", "", "")
|
||||
f.Add("Alert!", "Critical issue detected", 5, "fire", "alerts", "")
|
||||
|
||||
// Edge cases
|
||||
f.Add("x", "y", -1, "unknown", "t", "not-a-url")
|
||||
f.Add("x", "y", 100, "", "", "")
|
||||
f.Add("🔥 Fire Alert", "💀 Something broke", 5, "skull", "test", "")
|
||||
|
||||
// Long strings
|
||||
longStr := ""
|
||||
for i := 0; i < 10000; i++ {
|
||||
longStr += "x"
|
||||
}
|
||||
f.Add(longStr, longStr, 3, "tag", "topic", "https://example.com")
|
||||
|
||||
// Special characters
|
||||
f.Add("Title\x00with\x00nulls", "Message\nwith\nnewlines", 3, "tag", "topic", "")
|
||||
f.Add("<script>alert('xss')</script>", "```code```", 3, "", "", "")
|
||||
f.Add("Title\t\r\n", "Body\t\r\n", 3, "", "", "")
|
||||
|
||||
f.Fuzz(func(t *testing.T, title, message string, priority int, tag, topic, click string) {
|
||||
msg := ntfy.Message{
|
||||
Title: title,
|
||||
Message: message,
|
||||
Priority: priority,
|
||||
Tags: []string{tag},
|
||||
Topic: topic,
|
||||
Click: click,
|
||||
Time: 1706803200,
|
||||
}
|
||||
|
||||
client := NewClient()
|
||||
|
||||
// Should never panic
|
||||
embed := client.buildEmbed(msg)
|
||||
|
||||
// Resulting embed should be valid
|
||||
if embed.Footer == nil {
|
||||
t.Error("Footer should never be nil")
|
||||
}
|
||||
|
||||
// Color should always be set to a valid value
|
||||
if embed.Color == 0 {
|
||||
t.Error("Color should never be 0")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// FuzzExtractEmoji tests emoji extraction doesn't panic
|
||||
func FuzzExtractEmoji(f *testing.F) {
|
||||
// Valid tags
|
||||
f.Add("warning")
|
||||
f.Add("fire")
|
||||
f.Add("check")
|
||||
f.Add("rocket")
|
||||
|
||||
// Edge cases
|
||||
f.Add("")
|
||||
f.Add("unknown_tag")
|
||||
f.Add("WARNING") // uppercase
|
||||
f.Add("WaRnInG") // mixed case
|
||||
f.Add("\x00")
|
||||
f.Add("tag with spaces")
|
||||
f.Add("émoji")
|
||||
f.Add("🔥") // emoji as tag
|
||||
f.Add("a]b[c{d}") // special chars
|
||||
|
||||
f.Fuzz(func(t *testing.T, tag string) {
|
||||
client := NewClient()
|
||||
|
||||
// Should never panic
|
||||
tags := []string{tag}
|
||||
_ = client.extractEmoji(tags)
|
||||
|
||||
// Multiple tags
|
||||
_ = client.extractEmoji([]string{tag, tag, tag})
|
||||
|
||||
// Empty slice
|
||||
_ = client.extractEmoji([]string{})
|
||||
|
||||
// Nil slice
|
||||
_ = client.extractEmoji(nil)
|
||||
})
|
||||
}
|
||||
|
||||
// FuzzWebhookPayloadJSON tests JSON marshaling of payloads
|
||||
func FuzzWebhookPayloadJSON(f *testing.F) {
|
||||
f.Add("Title", "Description", 3066993, "Topic", "value", "footer")
|
||||
|
||||
f.Fuzz(func(t *testing.T, title, desc string, color int, fieldName, fieldValue, footer string) {
|
||||
payload := WebhookPayload{
|
||||
Embeds: []Embed{
|
||||
{
|
||||
Title: title,
|
||||
Description: desc,
|
||||
Color: color,
|
||||
Fields: []Field{
|
||||
{Name: fieldName, Value: fieldValue, Inline: true},
|
||||
},
|
||||
Footer: &Footer{Text: footer},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
// Marshaling should not panic
|
||||
_, err := json.Marshal(payload)
|
||||
if err != nil {
|
||||
// JSON encoding errors are acceptable for invalid UTF-8
|
||||
// but should not panic
|
||||
}
|
||||
})
|
||||
}
|
||||
203
internal/discord/webhook.go
Normal file
203
internal/discord/webhook.go
Normal file
@@ -0,0 +1,203 @@
|
||||
package discord
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"git.daviestechlabs.io/daviestechlabs/ntfy-discord/internal/ntfy"
|
||||
)
|
||||
|
||||
// Priority to Discord embed color mapping
|
||||
var priorityColors = map[int]int{
|
||||
5: 15158332, // Red - Max/Urgent
|
||||
4: 15105570, // Orange - High
|
||||
3: 3066993, // Blue - Default
|
||||
2: 9807270, // Gray - Low
|
||||
1: 12370112, // Light Gray - Min
|
||||
}
|
||||
|
||||
// Tag to emoji mapping
|
||||
var tagEmojis = map[string]string{
|
||||
"white_check_mark": "✅",
|
||||
"heavy_check_mark": "✅",
|
||||
"check": "✅",
|
||||
"x": "❌",
|
||||
"skull": "❌",
|
||||
"warning": "⚠️",
|
||||
"rotating_light": "🚨",
|
||||
"rocket": "🚀",
|
||||
"package": "📦",
|
||||
"tada": "🎉",
|
||||
"fire": "🔥",
|
||||
"bug": "🐛",
|
||||
"wrench": "🔧",
|
||||
"gear": "⚙️",
|
||||
"lock": "🔒",
|
||||
"key": "🔑",
|
||||
"bell": "🔔",
|
||||
"mega": "📢",
|
||||
"eyes": "👀",
|
||||
"sos": "🆘",
|
||||
"no_entry": "⛔",
|
||||
"construction": "🚧",
|
||||
}
|
||||
|
||||
// Embed represents a Discord embed
|
||||
type Embed struct {
|
||||
Title string `json:"title,omitempty"`
|
||||
Description string `json:"description,omitempty"`
|
||||
Color int `json:"color,omitempty"`
|
||||
Fields []Field `json:"fields,omitempty"`
|
||||
Timestamp string `json:"timestamp,omitempty"`
|
||||
Footer *Footer `json:"footer,omitempty"`
|
||||
}
|
||||
|
||||
// Field represents a Discord embed field
|
||||
type Field struct {
|
||||
Name string `json:"name"`
|
||||
Value string `json:"value"`
|
||||
Inline bool `json:"inline,omitempty"`
|
||||
}
|
||||
|
||||
// Footer represents a Discord embed footer
|
||||
type Footer struct {
|
||||
Text string `json:"text"`
|
||||
}
|
||||
|
||||
// WebhookPayload is the Discord webhook request body
|
||||
type WebhookPayload struct {
|
||||
Embeds []Embed `json:"embeds"`
|
||||
}
|
||||
|
||||
// Client sends messages to Discord webhooks
|
||||
type Client struct {
|
||||
httpClient *http.Client
|
||||
}
|
||||
|
||||
// NewClient creates a new Discord webhook client
|
||||
func NewClient() *Client {
|
||||
return &Client{
|
||||
httpClient: &http.Client{
|
||||
Timeout: 30 * time.Second,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// Send converts an ntfy message to Discord format and sends it
|
||||
func (c *Client) Send(ctx context.Context, webhookURL string, msg ntfy.Message) error {
|
||||
if webhookURL == "" {
|
||||
return fmt.Errorf("no webhook URL configured")
|
||||
}
|
||||
|
||||
embed := c.buildEmbed(msg)
|
||||
payload := WebhookPayload{
|
||||
Embeds: []Embed{embed},
|
||||
}
|
||||
|
||||
body, err := json.Marshal(payload)
|
||||
if err != nil {
|
||||
return fmt.Errorf("marshal payload: %w", err)
|
||||
}
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, webhookURL, bytes.NewReader(body))
|
||||
if err != nil {
|
||||
return fmt.Errorf("create request: %w", err)
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
resp, err := c.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return fmt.Errorf("send request: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
// Handle rate limiting
|
||||
if resp.StatusCode == http.StatusTooManyRequests {
|
||||
retryAfter := resp.Header.Get("Retry-After")
|
||||
if retryAfter != "" {
|
||||
if seconds, err := strconv.Atoi(retryAfter); err == nil {
|
||||
slog.Warn("Discord rate limited", "retry_after", seconds)
|
||||
time.Sleep(time.Duration(seconds) * time.Second)
|
||||
return c.Send(ctx, webhookURL, msg) // Retry
|
||||
}
|
||||
}
|
||||
return fmt.Errorf("rate limited")
|
||||
}
|
||||
|
||||
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
|
||||
respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 1024))
|
||||
return fmt.Errorf("unexpected status %d: %s", resp.StatusCode, string(respBody))
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Client) buildEmbed(msg ntfy.Message) Embed {
|
||||
// Get color from priority
|
||||
color := priorityColors[3] // Default
|
||||
if msg.Priority >= 1 && msg.Priority <= 5 {
|
||||
color = priorityColors[msg.Priority]
|
||||
}
|
||||
|
||||
// Build title with emoji prefix
|
||||
title := msg.Title
|
||||
if title == "" {
|
||||
title = msg.Topic
|
||||
}
|
||||
emoji := c.extractEmoji(msg.Tags)
|
||||
if emoji != "" {
|
||||
title = emoji + " " + title
|
||||
}
|
||||
|
||||
// Build timestamp
|
||||
timestamp := ""
|
||||
if msg.Time > 0 {
|
||||
timestamp = time.Unix(msg.Time, 0).UTC().Format(time.RFC3339)
|
||||
}
|
||||
|
||||
embed := Embed{
|
||||
Title: title,
|
||||
Description: msg.Message,
|
||||
Color: color,
|
||||
Timestamp: timestamp,
|
||||
Footer: &Footer{Text: "ntfy"},
|
||||
}
|
||||
|
||||
// Add topic as field
|
||||
if msg.Topic != "" {
|
||||
embed.Fields = append(embed.Fields, Field{
|
||||
Name: "Topic",
|
||||
Value: msg.Topic,
|
||||
Inline: true,
|
||||
})
|
||||
}
|
||||
|
||||
// Add click URL if present
|
||||
if msg.Click != "" {
|
||||
embed.Fields = append(embed.Fields, Field{
|
||||
Name: "Link",
|
||||
Value: msg.Click,
|
||||
Inline: false,
|
||||
})
|
||||
}
|
||||
|
||||
return embed
|
||||
}
|
||||
|
||||
func (c *Client) extractEmoji(tags []string) string {
|
||||
for _, tag := range tags {
|
||||
tag = strings.ToLower(tag)
|
||||
if emoji, ok := tagEmojis[tag]; ok {
|
||||
return emoji
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
260
internal/discord/webhook_test.go
Normal file
260
internal/discord/webhook_test.go
Normal file
@@ -0,0 +1,260 @@
|
||||
package discord
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"git.daviestechlabs.io/daviestechlabs/ntfy-discord/internal/ntfy"
|
||||
)
|
||||
|
||||
func TestClient_Send(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
msg ntfy.Message
|
||||
wantStatus int
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "successful send",
|
||||
msg: ntfy.Message{
|
||||
ID: "test-id",
|
||||
Topic: "alerts",
|
||||
Title: "Test Alert",
|
||||
Message: "This is a test message",
|
||||
Priority: 3,
|
||||
Time: time.Now().Unix(),
|
||||
},
|
||||
wantStatus: http.StatusNoContent,
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "high priority message",
|
||||
msg: ntfy.Message{
|
||||
ID: "high-priority",
|
||||
Topic: "urgent",
|
||||
Title: "Urgent Alert",
|
||||
Message: "Critical issue detected",
|
||||
Priority: 5,
|
||||
Tags: []string{"warning", "fire"},
|
||||
},
|
||||
wantStatus: http.StatusNoContent,
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "server error",
|
||||
msg: ntfy.Message{
|
||||
ID: "error-test",
|
||||
Topic: "test",
|
||||
},
|
||||
wantStatus: http.StatusInternalServerError,
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
var receivedPayload WebhookPayload
|
||||
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != http.MethodPost {
|
||||
t.Errorf("expected POST, got %s", r.Method)
|
||||
}
|
||||
if r.Header.Get("Content-Type") != "application/json" {
|
||||
t.Errorf("expected application/json, got %s", r.Header.Get("Content-Type"))
|
||||
}
|
||||
|
||||
if err := json.NewDecoder(r.Body).Decode(&receivedPayload); err != nil {
|
||||
t.Errorf("failed to decode payload: %v", err)
|
||||
}
|
||||
|
||||
w.WriteHeader(tt.wantStatus)
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
client := NewClient()
|
||||
err := client.Send(context.Background(), server.URL, tt.msg)
|
||||
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("Send() error = %v, wantErr %v", err, tt.wantErr)
|
||||
}
|
||||
|
||||
if !tt.wantErr && len(receivedPayload.Embeds) != 1 {
|
||||
t.Errorf("expected 1 embed, got %d", len(receivedPayload.Embeds))
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestClient_Send_NoWebhookURL(t *testing.T) {
|
||||
client := NewClient()
|
||||
err := client.Send(context.Background(), "", ntfy.Message{})
|
||||
|
||||
if err == nil {
|
||||
t.Error("expected error for empty webhook URL")
|
||||
}
|
||||
}
|
||||
|
||||
func TestClient_buildEmbed(t *testing.T) {
|
||||
client := NewClient()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
msg ntfy.Message
|
||||
wantColor int
|
||||
wantEmoji string
|
||||
}{
|
||||
{
|
||||
name: "default priority",
|
||||
msg: ntfy.Message{
|
||||
Title: "Test",
|
||||
Message: "Hello",
|
||||
Priority: 3,
|
||||
},
|
||||
wantColor: 3066993, // Blue
|
||||
wantEmoji: "",
|
||||
},
|
||||
{
|
||||
name: "max priority with warning tag",
|
||||
msg: ntfy.Message{
|
||||
Title: "Alert",
|
||||
Message: "Critical",
|
||||
Priority: 5,
|
||||
Tags: []string{"warning"},
|
||||
},
|
||||
wantColor: 15158332, // Red
|
||||
wantEmoji: "⚠️",
|
||||
},
|
||||
{
|
||||
name: "low priority",
|
||||
msg: ntfy.Message{
|
||||
Title: "Info",
|
||||
Message: "Low priority",
|
||||
Priority: 2,
|
||||
},
|
||||
wantColor: 9807270, // Gray
|
||||
wantEmoji: "",
|
||||
},
|
||||
{
|
||||
name: "with check tag",
|
||||
msg: ntfy.Message{
|
||||
Title: "Success",
|
||||
Message: "Completed",
|
||||
Priority: 3,
|
||||
Tags: []string{"check", "success"},
|
||||
},
|
||||
wantColor: 3066993,
|
||||
wantEmoji: "✅",
|
||||
},
|
||||
{
|
||||
name: "no title uses topic",
|
||||
msg: ntfy.Message{
|
||||
Topic: "alerts",
|
||||
Message: "No title",
|
||||
},
|
||||
wantColor: 3066993,
|
||||
wantEmoji: "",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
embed := client.buildEmbed(tt.msg)
|
||||
|
||||
if embed.Color != tt.wantColor {
|
||||
t.Errorf("Color = %d, want %d", embed.Color, tt.wantColor)
|
||||
}
|
||||
|
||||
if tt.wantEmoji != "" {
|
||||
if len(embed.Title) < 2 || embed.Title[:len(tt.wantEmoji)] != tt.wantEmoji {
|
||||
t.Errorf("Title should start with emoji %s, got %s", tt.wantEmoji, embed.Title)
|
||||
}
|
||||
}
|
||||
|
||||
if embed.Description != tt.msg.Message {
|
||||
t.Errorf("Description = %s, want %s", embed.Description, tt.msg.Message)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestClient_extractEmoji(t *testing.T) {
|
||||
client := NewClient()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
tags []string
|
||||
want string
|
||||
}{
|
||||
{"warning tag", []string{"warning"}, "⚠️"},
|
||||
{"check tag", []string{"check"}, "✅"},
|
||||
{"fire tag", []string{"fire"}, "🔥"},
|
||||
{"rocket tag", []string{"rocket"}, "🚀"},
|
||||
{"unknown tag", []string{"unknown"}, ""},
|
||||
{"empty tags", []string{}, ""},
|
||||
{"multiple tags first match", []string{"unknown", "fire", "warning"}, "🔥"},
|
||||
{"case insensitive", []string{"WARNING"}, "⚠️"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := client.extractEmoji(tt.tags)
|
||||
if got != tt.want {
|
||||
t.Errorf("extractEmoji(%v) = %s, want %s", tt.tags, got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestPriorityColors(t *testing.T) {
|
||||
expected := map[int]int{
|
||||
1: 12370112, // Light Gray
|
||||
2: 9807270, // Gray
|
||||
3: 3066993, // Blue
|
||||
4: 15105570, // Orange
|
||||
5: 15158332, // Red
|
||||
}
|
||||
|
||||
for priority, color := range expected {
|
||||
if priorityColors[priority] != color {
|
||||
t.Errorf("priorityColors[%d] = %d, want %d", priority, priorityColors[priority], color)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestWebhookPayload_JSON(t *testing.T) {
|
||||
payload := WebhookPayload{
|
||||
Embeds: []Embed{
|
||||
{
|
||||
Title: "Test",
|
||||
Description: "Hello World",
|
||||
Color: 3066993,
|
||||
Fields: []Field{
|
||||
{Name: "Topic", Value: "alerts", Inline: true},
|
||||
},
|
||||
Footer: &Footer{Text: "ntfy"},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
data, err := json.Marshal(payload)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to marshal: %v", err)
|
||||
}
|
||||
|
||||
var decoded WebhookPayload
|
||||
if err := json.Unmarshal(data, &decoded); err != nil {
|
||||
t.Fatalf("failed to unmarshal: %v", err)
|
||||
}
|
||||
|
||||
if len(decoded.Embeds) != 1 {
|
||||
t.Errorf("expected 1 embed, got %d", len(decoded.Embeds))
|
||||
}
|
||||
|
||||
if decoded.Embeds[0].Title != "Test" {
|
||||
t.Errorf("Title = %s, want Test", decoded.Embeds[0].Title)
|
||||
}
|
||||
}
|
||||
137
internal/ntfy/client.go
Normal file
137
internal/ntfy/client.go
Normal file
@@ -0,0 +1,137 @@
|
||||
package ntfy
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
// Message represents an ntfy message
|
||||
type Message struct {
|
||||
ID string `json:"id"`
|
||||
Time int64 `json:"time"`
|
||||
Expires int64 `json:"expires,omitempty"`
|
||||
Event string `json:"event"`
|
||||
Topic string `json:"topic"`
|
||||
Title string `json:"title,omitempty"`
|
||||
Message string `json:"message,omitempty"`
|
||||
Priority int `json:"priority,omitempty"`
|
||||
Tags []string `json:"tags,omitempty"`
|
||||
Click string `json:"click,omitempty"`
|
||||
Icon string `json:"icon,omitempty"`
|
||||
}
|
||||
|
||||
// Client subscribes to ntfy topics via SSE
|
||||
type Client struct {
|
||||
baseURL string
|
||||
topics []string
|
||||
client *http.Client
|
||||
}
|
||||
|
||||
// NewClient creates a new ntfy SSE client
|
||||
func NewClient(baseURL string, topics []string) *Client {
|
||||
return &Client{
|
||||
baseURL: strings.TrimSuffix(baseURL, "/"),
|
||||
topics: topics,
|
||||
client: &http.Client{
|
||||
Timeout: 0, // No timeout for SSE
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// Subscribe connects to ntfy and streams messages to the channel
|
||||
// It automatically reconnects with exponential backoff on failure
|
||||
func (c *Client) Subscribe(ctx context.Context, msgCh chan<- Message) {
|
||||
backoff := time.Second
|
||||
maxBackoff := time.Minute
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
default:
|
||||
}
|
||||
|
||||
err := c.connect(ctx, msgCh)
|
||||
if err != nil {
|
||||
if ctx.Err() != nil {
|
||||
return // Context cancelled
|
||||
}
|
||||
slog.Error("ntfy connection failed", "error", err, "backoff", backoff)
|
||||
time.Sleep(backoff)
|
||||
backoff = min(backoff*2, maxBackoff)
|
||||
continue
|
||||
}
|
||||
|
||||
// Reset backoff on successful connection
|
||||
backoff = time.Second
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Client) connect(ctx context.Context, msgCh chan<- Message) error {
|
||||
// Build URL with all topics
|
||||
topicPath := strings.Join(c.topics, ",")
|
||||
url := fmt.Sprintf("%s/%s/json", c.baseURL, topicPath)
|
||||
|
||||
slog.Info("connecting to ntfy", "url", url)
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
|
||||
if err != nil {
|
||||
return fmt.Errorf("create request: %w", err)
|
||||
}
|
||||
|
||||
resp, err := c.client.Do(req)
|
||||
if err != nil {
|
||||
return fmt.Errorf("connect: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
body, _ := io.ReadAll(io.LimitReader(resp.Body, 1024))
|
||||
return fmt.Errorf("unexpected status %d: %s", resp.StatusCode, string(body))
|
||||
}
|
||||
|
||||
slog.Info("connected to ntfy", "topics", c.topics)
|
||||
|
||||
// Read line by line (ntfy sends newline-delimited JSON)
|
||||
scanner := bufio.NewScanner(resp.Body)
|
||||
for scanner.Scan() {
|
||||
line := scanner.Text()
|
||||
if line == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
var msg Message
|
||||
if err := json.Unmarshal([]byte(line), &msg); err != nil {
|
||||
slog.Warn("failed to parse ntfy message", "error", err, "line", line)
|
||||
continue
|
||||
}
|
||||
|
||||
// Skip keepalive/open events
|
||||
if msg.Event == "keepalive" || msg.Event == "open" {
|
||||
slog.Debug("ntfy event", "event", msg.Event)
|
||||
continue
|
||||
}
|
||||
|
||||
if msg.Event == "message" {
|
||||
slog.Debug("received ntfy message",
|
||||
"id", msg.ID,
|
||||
"topic", msg.Topic,
|
||||
"title", msg.Title,
|
||||
)
|
||||
msgCh <- msg
|
||||
}
|
||||
}
|
||||
|
||||
if err := scanner.Err(); err != nil {
|
||||
return fmt.Errorf("read stream: %w", err)
|
||||
}
|
||||
|
||||
return fmt.Errorf("stream closed")
|
||||
}
|
||||
227
internal/ntfy/client_test.go
Normal file
227
internal/ntfy/client_test.go
Normal file
@@ -0,0 +1,227 @@
|
||||
package ntfy
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestNewClient(t *testing.T) {
|
||||
client := NewClient("http://ntfy.example.com", []string{"alerts", "updates"})
|
||||
|
||||
if client.baseURL != "http://ntfy.example.com" {
|
||||
t.Errorf("baseURL = %s, want http://ntfy.example.com", client.baseURL)
|
||||
}
|
||||
|
||||
if len(client.topics) != 2 {
|
||||
t.Errorf("topics length = %d, want 2", len(client.topics))
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewClient_TrimsTrailingSlash(t *testing.T) {
|
||||
client := NewClient("http://ntfy.example.com/", []string{"test"})
|
||||
|
||||
if client.baseURL != "http://ntfy.example.com" {
|
||||
t.Errorf("baseURL = %s, want http://ntfy.example.com (trailing slash removed)", client.baseURL)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMessage_JSON(t *testing.T) {
|
||||
msg := Message{
|
||||
ID: "test-123",
|
||||
Time: 1706803200,
|
||||
Event: "message",
|
||||
Topic: "alerts",
|
||||
Title: "Test Alert",
|
||||
Message: "This is a test",
|
||||
Priority: 4,
|
||||
Tags: []string{"warning", "fire"},
|
||||
Click: "https://example.com",
|
||||
}
|
||||
|
||||
data, err := json.Marshal(msg)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to marshal: %v", err)
|
||||
}
|
||||
|
||||
var decoded Message
|
||||
if err := json.Unmarshal(data, &decoded); err != nil {
|
||||
t.Fatalf("failed to unmarshal: %v", err)
|
||||
}
|
||||
|
||||
if decoded.ID != msg.ID {
|
||||
t.Errorf("ID = %s, want %s", decoded.ID, msg.ID)
|
||||
}
|
||||
|
||||
if decoded.Priority != msg.Priority {
|
||||
t.Errorf("Priority = %d, want %d", decoded.Priority, msg.Priority)
|
||||
}
|
||||
|
||||
if len(decoded.Tags) != 2 {
|
||||
t.Errorf("Tags length = %d, want 2", len(decoded.Tags))
|
||||
}
|
||||
}
|
||||
|
||||
func TestClient_Subscribe_ReceivesMessages(t *testing.T) {
|
||||
messages := []Message{
|
||||
{Event: "open"},
|
||||
{Event: "message", ID: "msg1", Topic: "test", Title: "First", Message: "Hello"},
|
||||
{Event: "keepalive"},
|
||||
{Event: "message", ID: "msg2", Topic: "test", Title: "Second", Message: "World"},
|
||||
}
|
||||
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
|
||||
for _, msg := range messages {
|
||||
data, _ := json.Marshal(msg)
|
||||
fmt.Fprintf(w, "%s\n", data)
|
||||
}
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
client := NewClient(server.URL, []string{"test"})
|
||||
msgCh := make(chan Message, 10)
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
|
||||
defer cancel()
|
||||
|
||||
go client.Subscribe(ctx, msgCh)
|
||||
|
||||
// Should receive only "message" events
|
||||
received := make([]Message, 0)
|
||||
timeout := time.After(1 * time.Second)
|
||||
|
||||
loop:
|
||||
for {
|
||||
select {
|
||||
case msg := <-msgCh:
|
||||
received = append(received, msg)
|
||||
if len(received) >= 2 {
|
||||
break loop
|
||||
}
|
||||
case <-timeout:
|
||||
break loop
|
||||
}
|
||||
}
|
||||
|
||||
if len(received) != 2 {
|
||||
t.Errorf("received %d messages, want 2", len(received))
|
||||
}
|
||||
|
||||
if len(received) > 0 && received[0].ID != "msg1" {
|
||||
t.Errorf("first message ID = %s, want msg1", received[0].ID)
|
||||
}
|
||||
|
||||
if len(received) > 1 && received[1].ID != "msg2" {
|
||||
t.Errorf("second message ID = %s, want msg2", received[1].ID)
|
||||
}
|
||||
}
|
||||
|
||||
func TestClient_Subscribe_FilterEvents(t *testing.T) {
|
||||
messages := []Message{
|
||||
{Event: "open"},
|
||||
{Event: "keepalive"},
|
||||
{Event: "message", ID: "actual", Topic: "test"},
|
||||
}
|
||||
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
for _, msg := range messages {
|
||||
data, _ := json.Marshal(msg)
|
||||
fmt.Fprintf(w, "%s\n", data)
|
||||
}
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
client := NewClient(server.URL, []string{"test"})
|
||||
msgCh := make(chan Message, 10)
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
|
||||
defer cancel()
|
||||
|
||||
go client.Subscribe(ctx, msgCh)
|
||||
|
||||
select {
|
||||
case msg := <-msgCh:
|
||||
if msg.Event != "message" {
|
||||
t.Errorf("received event = %s, want message", msg.Event)
|
||||
}
|
||||
if msg.ID != "actual" {
|
||||
t.Errorf("message ID = %s, want actual", msg.ID)
|
||||
}
|
||||
case <-time.After(500 * time.Millisecond):
|
||||
t.Error("timeout waiting for message")
|
||||
}
|
||||
}
|
||||
|
||||
func TestClient_Subscribe_ContextCancellation(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
// Slow server - keeps connection open
|
||||
time.Sleep(10 * time.Second)
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
client := NewClient(server.URL, []string{"test"})
|
||||
msgCh := make(chan Message, 10)
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
client.Subscribe(ctx, msgCh)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
// Cancel quickly
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
cancel()
|
||||
|
||||
// Should exit promptly
|
||||
select {
|
||||
case <-done:
|
||||
// Success
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Error("Subscribe did not exit after context cancellation")
|
||||
}
|
||||
}
|
||||
|
||||
func TestClient_connect_ServerError(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
w.Write([]byte("internal error"))
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
client := NewClient(server.URL, []string{"test"})
|
||||
msgCh := make(chan Message, 10)
|
||||
|
||||
err := client.connect(context.Background(), msgCh)
|
||||
|
||||
if err == nil {
|
||||
t.Error("expected error for server error response")
|
||||
}
|
||||
}
|
||||
|
||||
func TestClient_connect_URLConstruction(t *testing.T) {
|
||||
var requestedURL string
|
||||
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
requestedURL = r.URL.Path
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
client := NewClient(server.URL, []string{"alerts", "updates"})
|
||||
client.connect(context.Background(), make(chan Message))
|
||||
|
||||
expected := "/alerts,updates/json"
|
||||
if requestedURL != expected {
|
||||
t.Errorf("URL path = %s, want %s", requestedURL, expected)
|
||||
}
|
||||
}
|
||||
58
internal/ntfy/fuzz_test.go
Normal file
58
internal/ntfy/fuzz_test.go
Normal file
@@ -0,0 +1,58 @@
|
||||
package ntfy
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"testing"
|
||||
)
|
||||
|
||||
// FuzzParseMessage tests that JSON unmarshaling doesn't panic on arbitrary input
|
||||
func FuzzParseMessage(f *testing.F) {
|
||||
// Seed corpus with valid messages
|
||||
f.Add(`{"event":"message","topic":"test","title":"hello","message":"world"}`)
|
||||
f.Add(`{"event":"message","topic":"alerts","priority":5,"tags":["warning"]}`)
|
||||
f.Add(`{"event":"keepalive"}`)
|
||||
f.Add(`{"event":"open"}`)
|
||||
f.Add(`{}`)
|
||||
f.Add(`{"id":"abc123","time":1706803200,"expires":1706889600}`)
|
||||
f.Add(`{"click":"https://example.com","icon":"https://example.com/icon.png"}`)
|
||||
|
||||
// Edge cases
|
||||
f.Add(``)
|
||||
f.Add(`null`)
|
||||
f.Add(`[]`)
|
||||
f.Add(`"string"`)
|
||||
f.Add(`123`)
|
||||
f.Add(`{"priority":-1}`)
|
||||
f.Add(`{"priority":999999999}`)
|
||||
f.Add(`{"tags":[]}`)
|
||||
f.Add(`{"tags":["","",""]}`)
|
||||
|
||||
f.Fuzz(func(t *testing.T, data string) {
|
||||
var msg Message
|
||||
// Should never panic regardless of input
|
||||
_ = json.Unmarshal([]byte(data), &msg)
|
||||
|
||||
// If it parsed, accessing fields should not panic
|
||||
_ = msg.ID
|
||||
_ = msg.Event
|
||||
_ = msg.Topic
|
||||
_ = msg.Title
|
||||
_ = msg.Message
|
||||
_ = msg.Priority
|
||||
_ = msg.Click
|
||||
_ = len(msg.Tags)
|
||||
})
|
||||
}
|
||||
|
||||
// FuzzParseMessageBytes tests binary input doesn't cause panics
|
||||
func FuzzParseMessageBytes(f *testing.F) {
|
||||
f.Add([]byte(`{"event":"message"}`))
|
||||
f.Add([]byte{0x00})
|
||||
f.Add([]byte{0xff, 0xfe})
|
||||
f.Add([]byte("\xef\xbb\xbf{}")) // BOM + JSON
|
||||
|
||||
f.Fuzz(func(t *testing.T, data []byte) {
|
||||
var msg Message
|
||||
_ = json.Unmarshal(data, &msg)
|
||||
})
|
||||
}
|
||||
93
internal/server/server.go
Normal file
93
internal/server/server.go
Normal file
@@ -0,0 +1,93 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"git.daviestechlabs.io/daviestechlabs/ntfy-discord/internal/bridge"
|
||||
"github.com/prometheus/client_golang/prometheus/promhttp"
|
||||
)
|
||||
|
||||
// Server provides health and metrics endpoints
|
||||
type Server struct {
|
||||
httpServer *http.Server
|
||||
bridge *bridge.Bridge
|
||||
}
|
||||
|
||||
// New creates a new HTTP server
|
||||
func New(port string, b *bridge.Bridge) *Server {
|
||||
mux := http.NewServeMux()
|
||||
|
||||
s := &Server{
|
||||
bridge: b,
|
||||
httpServer: &http.Server{
|
||||
Addr: ":" + port,
|
||||
Handler: mux,
|
||||
ReadTimeout: 5 * time.Second,
|
||||
WriteTimeout: 10 * time.Second,
|
||||
},
|
||||
}
|
||||
|
||||
mux.HandleFunc("/health", s.handleHealth)
|
||||
mux.HandleFunc("/ready", s.handleReady)
|
||||
mux.Handle("/metrics", promhttp.Handler())
|
||||
|
||||
return s
|
||||
}
|
||||
|
||||
// Start begins serving HTTP requests
|
||||
func (s *Server) Start() {
|
||||
slog.Info("starting HTTP server", "addr", s.httpServer.Addr)
|
||||
if err := s.httpServer.ListenAndServe(); err != http.ErrServerClosed {
|
||||
slog.Error("HTTP server error", "error", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Shutdown gracefully stops the server
|
||||
func (s *Server) Shutdown(ctx context.Context) {
|
||||
ctx, cancel := context.WithTimeout(ctx, 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
if err := s.httpServer.Shutdown(ctx); err != nil {
|
||||
slog.Error("HTTP server shutdown error", "error", err)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) handleHealth(w http.ResponseWriter, r *http.Request) {
|
||||
status := struct {
|
||||
Status string `json:"status"`
|
||||
Healthy bool `json:"healthy"`
|
||||
}{
|
||||
Status: "ok",
|
||||
Healthy: s.bridge.IsHealthy(),
|
||||
}
|
||||
|
||||
if !status.Healthy {
|
||||
status.Status = "unhealthy"
|
||||
w.WriteHeader(http.StatusServiceUnavailable)
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
json.NewEncoder(w).Encode(status)
|
||||
}
|
||||
|
||||
func (s *Server) handleReady(w http.ResponseWriter, r *http.Request) {
|
||||
status := struct {
|
||||
Status string `json:"status"`
|
||||
Ready bool `json:"ready"`
|
||||
}{
|
||||
Status: "ok",
|
||||
Ready: s.bridge.IsReady(),
|
||||
}
|
||||
|
||||
if !status.Ready {
|
||||
status.Status = "not ready"
|
||||
w.WriteHeader(http.StatusServiceUnavailable)
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
json.NewEncoder(w).Encode(status)
|
||||
}
|
||||
91
internal/server/server_test.go
Normal file
91
internal/server/server_test.go
Normal file
@@ -0,0 +1,91 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestServer_HealthEndpoint_StatusCodes(t *testing.T) {
|
||||
// Test health endpoint returns JSON
|
||||
mux := http.NewServeMux()
|
||||
mux.HandleFunc("/health", func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.Write([]byte(`{"status":"ok","healthy":true}`))
|
||||
})
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/health", nil)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
mux.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Errorf("status = %d, want %d", w.Code, http.StatusOK)
|
||||
}
|
||||
|
||||
if w.Header().Get("Content-Type") != "application/json" {
|
||||
t.Error("expected Content-Type application/json")
|
||||
}
|
||||
}
|
||||
|
||||
func TestServer_ReadyEndpoint_StatusCodes(t *testing.T) {
|
||||
// Test ready endpoint returns JSON
|
||||
mux := http.NewServeMux()
|
||||
mux.HandleFunc("/ready", func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.Write([]byte(`{"status":"ready","ready":true}`))
|
||||
})
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/ready", nil)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
mux.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Errorf("status = %d, want %d", w.Code, http.StatusOK)
|
||||
}
|
||||
}
|
||||
|
||||
func TestServer_Shutdown(t *testing.T) {
|
||||
// Create a minimal server for shutdown testing
|
||||
srv := &http.Server{
|
||||
Addr: ":0",
|
||||
Handler: http.NewServeMux(),
|
||||
}
|
||||
|
||||
// Start in background
|
||||
go srv.ListenAndServe()
|
||||
|
||||
// Give it a moment to start
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
|
||||
// Shutdown should complete without error
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
|
||||
defer cancel()
|
||||
|
||||
err := srv.Shutdown(ctx)
|
||||
if err != nil {
|
||||
t.Errorf("Shutdown() error = %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestServer_MetricsEndpoint(t *testing.T) {
|
||||
// Verify /metrics endpoint can be created
|
||||
// The actual promhttp.Handler() is tested by Prometheus library
|
||||
mux := http.NewServeMux()
|
||||
mux.HandleFunc("/metrics", func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write([]byte("# metrics here"))
|
||||
})
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/metrics", nil)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
mux.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Errorf("status = %d, want %d", w.Code, http.StatusOK)
|
||||
}
|
||||
}
|
||||
215
internal/vault/client.go
Normal file
215
internal/vault/client.go
Normal file
@@ -0,0 +1,215 @@
|
||||
package vault
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"os"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/hashicorp/vault/api"
|
||||
"github.com/hashicorp/vault/api/auth/kubernetes"
|
||||
)
|
||||
|
||||
// Client wraps the Vault API client with auto-renewal
|
||||
type Client struct {
|
||||
client *api.Client
|
||||
mountPath string
|
||||
secretPath string
|
||||
|
||||
mu sync.RWMutex
|
||||
secret *api.Secret
|
||||
}
|
||||
|
||||
// Config holds Vault client configuration
|
||||
type Config struct {
|
||||
// Address is the Vault server address (e.g., http://vault.vault.svc.cluster.local:8200)
|
||||
Address string
|
||||
// AuthMethod is either "kubernetes" or "token"
|
||||
AuthMethod string
|
||||
// Role is the Vault role for Kubernetes auth
|
||||
Role string
|
||||
// MountPath is the secrets engine mount (e.g., "secret")
|
||||
MountPath string
|
||||
// SecretPath is the path within the mount (e.g., "data/ntfy-discord")
|
||||
SecretPath string
|
||||
// TokenPath is the path to the Kubernetes service account token
|
||||
TokenPath string
|
||||
}
|
||||
|
||||
// NewClient creates a new Vault client
|
||||
func NewClient(cfg Config) (*Client, error) {
|
||||
vaultCfg := api.DefaultConfig()
|
||||
vaultCfg.Address = cfg.Address
|
||||
|
||||
client, err := api.NewClient(vaultCfg)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create vault client: %w", err)
|
||||
}
|
||||
|
||||
vc := &Client{
|
||||
client: client,
|
||||
mountPath: cfg.MountPath,
|
||||
secretPath: cfg.SecretPath,
|
||||
}
|
||||
|
||||
// Authenticate based on method
|
||||
switch strings.ToLower(cfg.AuthMethod) {
|
||||
case "kubernetes":
|
||||
if err := vc.authKubernetes(cfg.Role, cfg.TokenPath); err != nil {
|
||||
return nil, fmt.Errorf("kubernetes auth failed: %w", err)
|
||||
}
|
||||
case "token":
|
||||
// Token should be set via VAULT_TOKEN env var, which the API client reads automatically
|
||||
if client.Token() == "" {
|
||||
return nil, fmt.Errorf("VAULT_TOKEN environment variable not set")
|
||||
}
|
||||
default:
|
||||
return nil, fmt.Errorf("unsupported auth method: %s", cfg.AuthMethod)
|
||||
}
|
||||
|
||||
slog.Info("vault client authenticated",
|
||||
"address", cfg.Address,
|
||||
"auth_method", cfg.AuthMethod,
|
||||
"mount_path", cfg.MountPath,
|
||||
"secret_path", cfg.SecretPath,
|
||||
)
|
||||
|
||||
return vc, nil
|
||||
}
|
||||
|
||||
// authKubernetes authenticates using Kubernetes service account
|
||||
func (c *Client) authKubernetes(role, tokenPath string) error {
|
||||
if tokenPath == "" {
|
||||
tokenPath = "/var/run/secrets/kubernetes.io/serviceaccount/token"
|
||||
}
|
||||
|
||||
// Verify token file exists
|
||||
if _, err := os.Stat(tokenPath); err != nil {
|
||||
return fmt.Errorf("service account token not found at %s: %w", tokenPath, err)
|
||||
}
|
||||
|
||||
k8sAuth, err := kubernetes.NewKubernetesAuth(
|
||||
role,
|
||||
kubernetes.WithServiceAccountTokenPath(tokenPath),
|
||||
)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create kubernetes auth: %w", err)
|
||||
}
|
||||
|
||||
authInfo, err := c.client.Auth().Login(context.Background(), k8sAuth)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to login with kubernetes auth: %w", err)
|
||||
}
|
||||
|
||||
if authInfo == nil {
|
||||
return fmt.Errorf("no auth info returned from vault")
|
||||
}
|
||||
|
||||
slog.Info("authenticated to vault via kubernetes",
|
||||
"role", role,
|
||||
"token_ttl", authInfo.Auth.LeaseDuration,
|
||||
)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetSecret retrieves a secret value by key
|
||||
func (c *Client) GetSecret(ctx context.Context, key string) (string, error) {
|
||||
c.mu.RLock()
|
||||
secret := c.secret
|
||||
c.mu.RUnlock()
|
||||
|
||||
// Fetch secret if not cached
|
||||
if secret == nil {
|
||||
if err := c.refreshSecret(ctx); err != nil {
|
||||
return "", err
|
||||
}
|
||||
c.mu.RLock()
|
||||
secret = c.secret
|
||||
c.mu.RUnlock()
|
||||
}
|
||||
|
||||
// KV v2 stores data under "data" key
|
||||
data, ok := secret.Data["data"].(map[string]interface{})
|
||||
if !ok {
|
||||
// Try KV v1 format
|
||||
data = secret.Data
|
||||
}
|
||||
|
||||
value, ok := data[key]
|
||||
if !ok {
|
||||
return "", fmt.Errorf("key %q not found in secret", key)
|
||||
}
|
||||
|
||||
str, ok := value.(string)
|
||||
if !ok {
|
||||
return "", fmt.Errorf("key %q is not a string", key)
|
||||
}
|
||||
|
||||
return str, nil
|
||||
}
|
||||
|
||||
// refreshSecret fetches the secret from Vault
|
||||
func (c *Client) refreshSecret(ctx context.Context) error {
|
||||
// Construct full path for KV v2: mount/data/path
|
||||
fullPath := fmt.Sprintf("%s/data/%s", c.mountPath, c.secretPath)
|
||||
|
||||
secret, err := c.client.Logical().ReadWithContext(ctx, fullPath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to read secret: %w", err)
|
||||
}
|
||||
|
||||
if secret == nil {
|
||||
return fmt.Errorf("secret not found at %s", fullPath)
|
||||
}
|
||||
|
||||
c.mu.Lock()
|
||||
c.secret = secret
|
||||
c.mu.Unlock()
|
||||
|
||||
slog.Debug("refreshed secret from vault", "path", fullPath)
|
||||
return nil
|
||||
}
|
||||
|
||||
// WatchAndRefresh periodically refreshes the secret and renews the auth token
|
||||
func (c *Client) WatchAndRefresh(ctx context.Context, interval time.Duration) {
|
||||
ticker := time.NewTicker(interval)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ticker.C:
|
||||
if err := c.refreshSecret(ctx); err != nil {
|
||||
slog.Error("failed to refresh secret from vault", "error", err)
|
||||
} else {
|
||||
slog.Debug("vault secret refreshed")
|
||||
}
|
||||
|
||||
// Renew token if renewable
|
||||
if c.client.Token() != "" {
|
||||
if _, err := c.client.Auth().Token().RenewSelf(0); err != nil {
|
||||
slog.Warn("failed to renew vault token", "error", err)
|
||||
}
|
||||
}
|
||||
case <-ctx.Done():
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Close cleans up the client
|
||||
func (c *Client) Close() error {
|
||||
if c == nil || c.client == nil {
|
||||
return nil
|
||||
}
|
||||
// Revoke token on shutdown (optional, but good practice)
|
||||
if c.client.Token() != "" {
|
||||
if err := c.client.Auth().Token().RevokeSelf(""); err != nil {
|
||||
slog.Warn("failed to revoke vault token", "error", err)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
94
internal/vault/client_test.go
Normal file
94
internal/vault/client_test.go
Normal file
@@ -0,0 +1,94 @@
|
||||
package vault
|
||||
|
||||
import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestConfig_Defaults(t *testing.T) {
|
||||
cfg := Config{
|
||||
Address: "http://vault.vault.svc.cluster.local:8200",
|
||||
AuthMethod: "kubernetes",
|
||||
Role: "ntfy-discord",
|
||||
MountPath: "secret",
|
||||
SecretPath: "ntfy-discord",
|
||||
}
|
||||
|
||||
if cfg.Address == "" {
|
||||
t.Error("Address should not be empty")
|
||||
}
|
||||
|
||||
if cfg.AuthMethod != "kubernetes" {
|
||||
t.Errorf("AuthMethod = %s, want kubernetes", cfg.AuthMethod)
|
||||
}
|
||||
|
||||
if cfg.MountPath != "secret" {
|
||||
t.Errorf("MountPath = %s, want secret", cfg.MountPath)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewClient_InvalidAddress(t *testing.T) {
|
||||
cfg := Config{
|
||||
Address: "not-a-valid-url",
|
||||
AuthMethod: "kubernetes",
|
||||
Role: "test",
|
||||
MountPath: "secret",
|
||||
SecretPath: "test",
|
||||
}
|
||||
|
||||
// This should fail because there's no Kubernetes token
|
||||
_, err := NewClient(cfg)
|
||||
if err == nil {
|
||||
t.Error("expected error for kubernetes auth without token")
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewClient_TokenAuth_NoToken(t *testing.T) {
|
||||
cfg := Config{
|
||||
Address: "http://localhost:8200",
|
||||
AuthMethod: "token",
|
||||
MountPath: "secret",
|
||||
SecretPath: "test",
|
||||
}
|
||||
|
||||
// Should fail because VAULT_TOKEN is not set
|
||||
_, err := NewClient(cfg)
|
||||
if err == nil {
|
||||
t.Error("expected error for token auth without VAULT_TOKEN")
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewClient_UnsupportedAuthMethod(t *testing.T) {
|
||||
cfg := Config{
|
||||
Address: "http://localhost:8200",
|
||||
AuthMethod: "unsupported",
|
||||
MountPath: "secret",
|
||||
SecretPath: "test",
|
||||
}
|
||||
|
||||
_, err := NewClient(cfg)
|
||||
if err == nil {
|
||||
t.Error("expected error for unsupported auth method")
|
||||
}
|
||||
}
|
||||
|
||||
func TestClient_Close_Nil(t *testing.T) {
|
||||
// Test that Close doesn't panic on a partially initialized client
|
||||
c := &Client{}
|
||||
|
||||
// Should not panic
|
||||
err := c.Close()
|
||||
if err != nil {
|
||||
t.Errorf("Close() error = %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestConfig_TokenPathDefault(t *testing.T) {
|
||||
cfg := Config{
|
||||
TokenPath: "",
|
||||
}
|
||||
|
||||
// Default token path should be empty (will be set in authKubernetes)
|
||||
if cfg.TokenPath != "" {
|
||||
t.Errorf("TokenPath = %s, want empty", cfg.TokenPath)
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user