nix-cache-login/internal/config/config_test.go

177 lines
4.3 KiB
Go

package config
import (
	"os"
	"path/filepath"
	"strings"
	"testing"
)
// TestLoadValidConfig verifies that Load parses a config file containing every
// required key and leaves ClientSecret empty when client_secret is absent.
// Every key written to the fixture is asserted, including netrc_path: it holds
// no `$` references, so the env expansion exercised by the
// TestEnvVarExpansion* tests should leave it unchanged.
func TestLoadValidConfig(t *testing.T) {
	dir := t.TempDir()
	cfgFile := filepath.Join(dir, "config.toml")
	content := `
issuer = "https://id.example.com/realms/test"
client_id = "nix-cache"
cache_host = "cache.example.com"
netrc_path = "/home/user/.config/nix/netrc"
`
	if err := os.WriteFile(cfgFile, []byte(content), 0644); err != nil {
		t.Fatal(err)
	}
	cfg, err := Load(cfgFile)
	if err != nil {
		t.Fatalf("unexpected error: %v", err)
	}
	if cfg.Issuer != "https://id.example.com/realms/test" {
		t.Errorf("issuer = %q, want %q", cfg.Issuer, "https://id.example.com/realms/test")
	}
	if cfg.ClientID != "nix-cache" {
		t.Errorf("client_id = %q, want %q", cfg.ClientID, "nix-cache")
	}
	if cfg.CacheHost != "cache.example.com" {
		t.Errorf("cache_host = %q, want %q", cfg.CacheHost, "cache.example.com")
	}
	// netrc_path was previously written but never checked.
	if cfg.NetrcPath != "/home/user/.config/nix/netrc" {
		t.Errorf("netrc_path = %q, want %q", cfg.NetrcPath, "/home/user/.config/nix/netrc")
	}
	if cfg.ClientSecret != "" {
		t.Errorf("client_secret = %q, want empty", cfg.ClientSecret)
	}
}
// TestLoadConfigWithClientSecret checks that an optional client_secret key in
// the config file ends up in the loaded config's ClientSecret field.
func TestLoadConfigWithClientSecret(t *testing.T) {
	const wantSecret = "super-secret"

	path := filepath.Join(t.TempDir(), "config.toml")
	body := `
issuer = "https://id.example.com/realms/test"
client_id = "nix-cache-server"
client_secret = "super-secret"
cache_host = "cache.example.com"
netrc_path = "/tmp/netrc"
`
	writeErr := os.WriteFile(path, []byte(body), 0644)
	if writeErr != nil {
		t.Fatal(writeErr)
	}

	loaded, loadErr := Load(path)
	if loadErr != nil {
		t.Fatalf("unexpected error: %v", loadErr)
	}
	if got := loaded.ClientSecret; got != wantSecret {
		t.Errorf("client_secret = %q, want %q", got, wantSecret)
	}
}
// TestEnvVarExpansionInNetrcPath checks that a bare $VAR reference inside
// netrc_path is replaced with the variable's value when the config is loaded.
func TestEnvVarExpansionInNetrcPath(t *testing.T) {
	const wantPath = "/custom/config/nix/netrc"

	t.Setenv("TEST_CONFIG_DIR", "/custom/config")
	path := filepath.Join(t.TempDir(), "config.toml")
	body := `
issuer = "https://id.example.com/realms/test"
client_id = "nix-cache"
cache_host = "cache.example.com"
netrc_path = "$TEST_CONFIG_DIR/nix/netrc"
`
	writeErr := os.WriteFile(path, []byte(body), 0644)
	if writeErr != nil {
		t.Fatal(writeErr)
	}

	loaded, loadErr := Load(path)
	if loadErr != nil {
		t.Fatalf("unexpected error: %v", loadErr)
	}
	if got := loaded.NetrcPath; got != wantPath {
		t.Errorf("netrc_path = %q, want %q", got, wantPath)
	}
}
// TestEnvVarExpansionBraces checks that the braced ${VAR} form is also expanded
// inside netrc_path when the config is loaded.
func TestEnvVarExpansionBraces(t *testing.T) {
	const wantPath = "/home/testuser/.config/nix/netrc"

	t.Setenv("MY_HOME", "/home/testuser")
	path := filepath.Join(t.TempDir(), "config.toml")
	body := `
issuer = "https://id.example.com/realms/test"
client_id = "nix-cache"
cache_host = "cache.example.com"
netrc_path = "${MY_HOME}/.config/nix/netrc"
`
	writeErr := os.WriteFile(path, []byte(body), 0644)
	if writeErr != nil {
		t.Fatal(writeErr)
	}

	loaded, loadErr := Load(path)
	if loadErr != nil {
		t.Fatalf("unexpected error: %v", loadErr)
	}
	if got := loaded.NetrcPath; got != wantPath {
		t.Errorf("netrc_path = %q, want %q", got, wantPath)
	}
}
// TestMissingRequiredFields asserts that Load rejects a config file that omits
// any one of the required keys (issuer, client_id, cache_host, netrc_path), and
// that the returned error message mentions the missing key.
func TestMissingRequiredFields(t *testing.T) {
	cases := []struct {
		name    string
		content string
		wantErr string
	}{
		{
			name:    "missing issuer",
			content: "client_id = \"x\"\ncache_host = \"x\"\nnetrc_path = \"/tmp/x\"",
			wantErr: "issuer is required",
		},
		{
			name:    "missing client_id",
			content: "issuer = \"https://x\"\ncache_host = \"x\"\nnetrc_path = \"/tmp/x\"",
			wantErr: "client_id is required",
		},
		{
			name:    "missing cache_host",
			content: "issuer = \"https://x\"\nclient_id = \"x\"\nnetrc_path = \"/tmp/x\"",
			wantErr: "cache_host is required",
		},
		{
			name:    "missing netrc_path",
			content: "issuer = \"https://x\"\nclient_id = \"x\"\ncache_host = \"x\"",
			wantErr: "netrc_path is required",
		},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			path := filepath.Join(t.TempDir(), "config.toml")
			if writeErr := os.WriteFile(path, []byte(tc.content), 0644); writeErr != nil {
				t.Fatal(writeErr)
			}

			_, loadErr := Load(path)
			switch {
			case loadErr == nil:
				t.Fatal("expected error, got nil")
			case !contains(loadErr.Error(), tc.wantErr):
				t.Errorf("error = %q, want to contain %q", loadErr.Error(), tc.wantErr)
			}
		})
	}
}
// contains reports whether substr occurs anywhere within s. An empty substr
// always matches.
//
// It is a thin wrapper over strings.Contains, which replaces the previous
// hand-rolled byte-window scan with identical results. The name is kept so
// existing call sites keep compiling.
func contains(s, substr string) bool {
	return strings.Contains(s, substr)
}

// searchString reports whether substr occurs anywhere within s. An empty
// substr always matches; a substr longer than s never does.
//
// Deprecated: use strings.Contains directly. Kept only so that any other test
// file in this package that may call it still compiles.
func searchString(s, substr string) bool {
	return strings.Contains(s, substr)
}