initial working version
This commit is contained in:
parent
db6b90134d
commit
d986a0b31a
19 changed files with 1430 additions and 0 deletions
64
internal/config/config.go
Normal file
64
internal/config/config.go
Normal file
|
|
@ -0,0 +1,64 @@
|
|||
package config
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
|
||||
"github.com/adrg/xdg"
|
||||
toml "github.com/pelletier/go-toml/v2"
|
||||
)
|
||||
|
||||
type Config struct {
|
||||
Issuer string `toml:"issuer"`
|
||||
ClientID string `toml:"client_id"`
|
||||
ClientSecret string `toml:"client_secret,omitempty"`
|
||||
CacheHost string `toml:"cache_host"`
|
||||
NetrcPath string `toml:"netrc_path"`
|
||||
}
|
||||
|
||||
// Load reads the config from the given path, or from the default XDG location.
|
||||
func Load(path string) (*Config, error) {
|
||||
if path == "" {
|
||||
path = filepath.Join(xdg.ConfigHome, "nix-cache-login", "config.toml")
|
||||
}
|
||||
|
||||
data, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("reading config file: %w", err)
|
||||
}
|
||||
|
||||
var cfg Config
|
||||
if err := toml.Unmarshal(data, &cfg); err != nil {
|
||||
return nil, fmt.Errorf("parsing config file: %w", err)
|
||||
}
|
||||
|
||||
cfg.NetrcPath = os.ExpandEnv(cfg.NetrcPath)
|
||||
|
||||
if err := cfg.validate(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &cfg, nil
|
||||
}
|
||||
|
||||
func (c *Config) validate() error {
|
||||
if c.Issuer == "" {
|
||||
return fmt.Errorf("config: issuer is required")
|
||||
}
|
||||
if c.ClientID == "" {
|
||||
return fmt.Errorf("config: client_id is required")
|
||||
}
|
||||
if c.CacheHost == "" {
|
||||
return fmt.Errorf("config: cache_host is required")
|
||||
}
|
||||
if c.NetrcPath == "" {
|
||||
return fmt.Errorf("config: netrc_path is required")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// RefreshTokenPath returns the path to the stored refresh token.
|
||||
func RefreshTokenPath() string {
|
||||
return filepath.Join(xdg.ConfigHome, "nix-cache-login", "refresh-token")
|
||||
}
|
||||
177
internal/config/config_test.go
Normal file
177
internal/config/config_test.go
Normal file
|
|
@ -0,0 +1,177 @@
|
|||
package config
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestLoadValidConfig(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
cfgFile := filepath.Join(dir, "config.toml")
|
||||
|
||||
content := `
|
||||
issuer = "https://id.example.com/realms/test"
|
||||
client_id = "nix-cache"
|
||||
cache_host = "cache.example.com"
|
||||
netrc_path = "/home/user/.config/nix/netrc"
|
||||
`
|
||||
if err := os.WriteFile(cfgFile, []byte(content), 0644); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
cfg, err := Load(cfgFile)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
if cfg.Issuer != "https://id.example.com/realms/test" {
|
||||
t.Errorf("issuer = %q, want %q", cfg.Issuer, "https://id.example.com/realms/test")
|
||||
}
|
||||
if cfg.ClientID != "nix-cache" {
|
||||
t.Errorf("client_id = %q, want %q", cfg.ClientID, "nix-cache")
|
||||
}
|
||||
if cfg.CacheHost != "cache.example.com" {
|
||||
t.Errorf("cache_host = %q, want %q", cfg.CacheHost, "cache.example.com")
|
||||
}
|
||||
if cfg.ClientSecret != "" {
|
||||
t.Errorf("client_secret = %q, want empty", cfg.ClientSecret)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoadConfigWithClientSecret(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
cfgFile := filepath.Join(dir, "config.toml")
|
||||
|
||||
content := `
|
||||
issuer = "https://id.example.com/realms/test"
|
||||
client_id = "nix-cache-server"
|
||||
client_secret = "super-secret"
|
||||
cache_host = "cache.example.com"
|
||||
netrc_path = "/tmp/netrc"
|
||||
`
|
||||
if err := os.WriteFile(cfgFile, []byte(content), 0644); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
cfg, err := Load(cfgFile)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
if cfg.ClientSecret != "super-secret" {
|
||||
t.Errorf("client_secret = %q, want %q", cfg.ClientSecret, "super-secret")
|
||||
}
|
||||
}
|
||||
|
||||
func TestEnvVarExpansionInNetrcPath(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
cfgFile := filepath.Join(dir, "config.toml")
|
||||
|
||||
t.Setenv("TEST_CONFIG_DIR", "/custom/config")
|
||||
|
||||
content := `
|
||||
issuer = "https://id.example.com/realms/test"
|
||||
client_id = "nix-cache"
|
||||
cache_host = "cache.example.com"
|
||||
netrc_path = "$TEST_CONFIG_DIR/nix/netrc"
|
||||
`
|
||||
if err := os.WriteFile(cfgFile, []byte(content), 0644); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
cfg, err := Load(cfgFile)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
if cfg.NetrcPath != "/custom/config/nix/netrc" {
|
||||
t.Errorf("netrc_path = %q, want %q", cfg.NetrcPath, "/custom/config/nix/netrc")
|
||||
}
|
||||
}
|
||||
|
||||
func TestEnvVarExpansionBraces(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
cfgFile := filepath.Join(dir, "config.toml")
|
||||
|
||||
t.Setenv("MY_HOME", "/home/testuser")
|
||||
|
||||
content := `
|
||||
issuer = "https://id.example.com/realms/test"
|
||||
client_id = "nix-cache"
|
||||
cache_host = "cache.example.com"
|
||||
netrc_path = "${MY_HOME}/.config/nix/netrc"
|
||||
`
|
||||
if err := os.WriteFile(cfgFile, []byte(content), 0644); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
cfg, err := Load(cfgFile)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
if cfg.NetrcPath != "/home/testuser/.config/nix/netrc" {
|
||||
t.Errorf("netrc_path = %q, want %q", cfg.NetrcPath, "/home/testuser/.config/nix/netrc")
|
||||
}
|
||||
}
|
||||
|
||||
func TestMissingRequiredFields(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
content string
|
||||
errMsg string
|
||||
}{
|
||||
{
|
||||
name: "missing issuer",
|
||||
content: `client_id = "x"` + "\n" + `cache_host = "x"` + "\n" + `netrc_path = "/tmp/x"`,
|
||||
errMsg: "issuer is required",
|
||||
},
|
||||
{
|
||||
name: "missing client_id",
|
||||
content: `issuer = "https://x"` + "\n" + `cache_host = "x"` + "\n" + `netrc_path = "/tmp/x"`,
|
||||
errMsg: "client_id is required",
|
||||
},
|
||||
{
|
||||
name: "missing cache_host",
|
||||
content: `issuer = "https://x"` + "\n" + `client_id = "x"` + "\n" + `netrc_path = "/tmp/x"`,
|
||||
errMsg: "cache_host is required",
|
||||
},
|
||||
{
|
||||
name: "missing netrc_path",
|
||||
content: `issuer = "https://x"` + "\n" + `client_id = "x"` + "\n" + `cache_host = "x"`,
|
||||
errMsg: "netrc_path is required",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
cfgFile := filepath.Join(dir, "config.toml")
|
||||
if err := os.WriteFile(cfgFile, []byte(tt.content), 0644); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
_, err := Load(cfgFile)
|
||||
if err == nil {
|
||||
t.Fatal("expected error, got nil")
|
||||
}
|
||||
if !contains(err.Error(), tt.errMsg) {
|
||||
t.Errorf("error = %q, want to contain %q", err.Error(), tt.errMsg)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func contains(s, substr string) bool {
|
||||
return len(s) >= len(substr) && searchString(s, substr)
|
||||
}
|
||||
|
||||
func searchString(s, substr string) bool {
|
||||
for i := 0; i <= len(s)-len(substr); i++ {
|
||||
if s[i:i+len(substr)] == substr {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
Loading…
Add table
Add a link
Reference in a new issue