initial working version

This commit is contained in:
Abel Luck 2026-02-26 11:05:16 +01:00
parent db6b90134d
commit d986a0b31a
19 changed files with 1430 additions and 0 deletions

View file

@ -0,0 +1,27 @@
package browser
import (
"fmt"
"os/exec"
"runtime"
)
// Open opens the given URL in the user's default browser.
func Open(url string) error {
var cmd *exec.Cmd
switch runtime.GOOS {
case "linux":
cmd = exec.Command("xdg-open", url)
case "darwin":
cmd = exec.Command("open", url)
default:
return fmt.Errorf("unsupported platform %s; open this URL manually:\n %s", runtime.GOOS, url)
}
if err := cmd.Start(); err != nil {
return fmt.Errorf("failed to open browser: %w\n Open this URL manually:\n %s", err, url)
}
return nil
}

64
internal/config/config.go Normal file
View file

@ -0,0 +1,64 @@
package config
import (
"fmt"
"os"
"path/filepath"
"github.com/adrg/xdg"
toml "github.com/pelletier/go-toml/v2"
)
type Config struct {
Issuer string `toml:"issuer"`
ClientID string `toml:"client_id"`
ClientSecret string `toml:"client_secret,omitempty"`
CacheHost string `toml:"cache_host"`
NetrcPath string `toml:"netrc_path"`
}
// Load reads the config from the given path, or from the default XDG location.
func Load(path string) (*Config, error) {
if path == "" {
path = filepath.Join(xdg.ConfigHome, "nix-cache-login", "config.toml")
}
data, err := os.ReadFile(path)
if err != nil {
return nil, fmt.Errorf("reading config file: %w", err)
}
var cfg Config
if err := toml.Unmarshal(data, &cfg); err != nil {
return nil, fmt.Errorf("parsing config file: %w", err)
}
cfg.NetrcPath = os.ExpandEnv(cfg.NetrcPath)
if err := cfg.validate(); err != nil {
return nil, err
}
return &cfg, nil
}
func (c *Config) validate() error {
if c.Issuer == "" {
return fmt.Errorf("config: issuer is required")
}
if c.ClientID == "" {
return fmt.Errorf("config: client_id is required")
}
if c.CacheHost == "" {
return fmt.Errorf("config: cache_host is required")
}
if c.NetrcPath == "" {
return fmt.Errorf("config: netrc_path is required")
}
return nil
}
// RefreshTokenPath returns the path to the stored refresh token.
func RefreshTokenPath() string {
return filepath.Join(xdg.ConfigHome, "nix-cache-login", "refresh-token")
}

View file

@ -0,0 +1,177 @@
package config
import (
"os"
"path/filepath"
"testing"
)
func TestLoadValidConfig(t *testing.T) {
dir := t.TempDir()
cfgFile := filepath.Join(dir, "config.toml")
content := `
issuer = "https://id.example.com/realms/test"
client_id = "nix-cache"
cache_host = "cache.example.com"
netrc_path = "/home/user/.config/nix/netrc"
`
if err := os.WriteFile(cfgFile, []byte(content), 0644); err != nil {
t.Fatal(err)
}
cfg, err := Load(cfgFile)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if cfg.Issuer != "https://id.example.com/realms/test" {
t.Errorf("issuer = %q, want %q", cfg.Issuer, "https://id.example.com/realms/test")
}
if cfg.ClientID != "nix-cache" {
t.Errorf("client_id = %q, want %q", cfg.ClientID, "nix-cache")
}
if cfg.CacheHost != "cache.example.com" {
t.Errorf("cache_host = %q, want %q", cfg.CacheHost, "cache.example.com")
}
if cfg.ClientSecret != "" {
t.Errorf("client_secret = %q, want empty", cfg.ClientSecret)
}
}
func TestLoadConfigWithClientSecret(t *testing.T) {
dir := t.TempDir()
cfgFile := filepath.Join(dir, "config.toml")
content := `
issuer = "https://id.example.com/realms/test"
client_id = "nix-cache-server"
client_secret = "super-secret"
cache_host = "cache.example.com"
netrc_path = "/tmp/netrc"
`
if err := os.WriteFile(cfgFile, []byte(content), 0644); err != nil {
t.Fatal(err)
}
cfg, err := Load(cfgFile)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if cfg.ClientSecret != "super-secret" {
t.Errorf("client_secret = %q, want %q", cfg.ClientSecret, "super-secret")
}
}
func TestEnvVarExpansionInNetrcPath(t *testing.T) {
dir := t.TempDir()
cfgFile := filepath.Join(dir, "config.toml")
t.Setenv("TEST_CONFIG_DIR", "/custom/config")
content := `
issuer = "https://id.example.com/realms/test"
client_id = "nix-cache"
cache_host = "cache.example.com"
netrc_path = "$TEST_CONFIG_DIR/nix/netrc"
`
if err := os.WriteFile(cfgFile, []byte(content), 0644); err != nil {
t.Fatal(err)
}
cfg, err := Load(cfgFile)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if cfg.NetrcPath != "/custom/config/nix/netrc" {
t.Errorf("netrc_path = %q, want %q", cfg.NetrcPath, "/custom/config/nix/netrc")
}
}
func TestEnvVarExpansionBraces(t *testing.T) {
dir := t.TempDir()
cfgFile := filepath.Join(dir, "config.toml")
t.Setenv("MY_HOME", "/home/testuser")
content := `
issuer = "https://id.example.com/realms/test"
client_id = "nix-cache"
cache_host = "cache.example.com"
netrc_path = "${MY_HOME}/.config/nix/netrc"
`
if err := os.WriteFile(cfgFile, []byte(content), 0644); err != nil {
t.Fatal(err)
}
cfg, err := Load(cfgFile)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if cfg.NetrcPath != "/home/testuser/.config/nix/netrc" {
t.Errorf("netrc_path = %q, want %q", cfg.NetrcPath, "/home/testuser/.config/nix/netrc")
}
}
func TestMissingRequiredFields(t *testing.T) {
tests := []struct {
name string
content string
errMsg string
}{
{
name: "missing issuer",
content: `client_id = "x"` + "\n" + `cache_host = "x"` + "\n" + `netrc_path = "/tmp/x"`,
errMsg: "issuer is required",
},
{
name: "missing client_id",
content: `issuer = "https://x"` + "\n" + `cache_host = "x"` + "\n" + `netrc_path = "/tmp/x"`,
errMsg: "client_id is required",
},
{
name: "missing cache_host",
content: `issuer = "https://x"` + "\n" + `client_id = "x"` + "\n" + `netrc_path = "/tmp/x"`,
errMsg: "cache_host is required",
},
{
name: "missing netrc_path",
content: `issuer = "https://x"` + "\n" + `client_id = "x"` + "\n" + `cache_host = "x"`,
errMsg: "netrc_path is required",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
dir := t.TempDir()
cfgFile := filepath.Join(dir, "config.toml")
if err := os.WriteFile(cfgFile, []byte(tt.content), 0644); err != nil {
t.Fatal(err)
}
_, err := Load(cfgFile)
if err == nil {
t.Fatal("expected error, got nil")
}
if !contains(err.Error(), tt.errMsg) {
t.Errorf("error = %q, want to contain %q", err.Error(), tt.errMsg)
}
})
}
}
func contains(s, substr string) bool {
return len(s) >= len(substr) && searchString(s, substr)
}
func searchString(s, substr string) bool {
for i := 0; i <= len(s)-len(substr); i++ {
if s[i:i+len(substr)] == substr {
return true
}
}
return false
}

136
internal/netrc/netrc.go Normal file
View file

@ -0,0 +1,136 @@
package netrc
import (
"bufio"
"fmt"
"os"
"path/filepath"
"strings"
)
// Upsert updates or inserts a machine entry in the netrc file.
// Only the password field is written (Nix uses password from netrc as auth).
func Upsert(path, machine, password string) error {
entries, err := parse(path)
if err != nil && !os.IsNotExist(err) {
return err
}
found := false
for i, e := range entries {
if e.machine == machine {
entries[i].password = password
found = true
break
}
}
if !found {
entries = append(entries, entry{machine: machine, password: password})
}
return write(path, entries)
}
// Remove removes the entry for the given machine from the netrc file.
func Remove(path, machine string) error {
entries, err := parse(path)
if err != nil {
if os.IsNotExist(err) {
return nil
}
return err
}
var filtered []entry
for _, e := range entries {
if e.machine != machine {
filtered = append(filtered, e)
}
}
return write(path, filtered)
}
// GetPassword returns the password for the given machine, or empty string if not found.
func GetPassword(path, machine string) (string, error) {
entries, err := parse(path)
if err != nil {
if os.IsNotExist(err) {
return "", nil
}
return "", err
}
for _, e := range entries {
if e.machine == machine {
return e.password, nil
}
}
return "", nil
}
type entry struct {
machine string
password string
}
func parse(path string) ([]entry, error) {
f, err := os.Open(path)
if err != nil {
return nil, err
}
defer f.Close()
var entries []entry
var current *entry
scanner := bufio.NewScanner(f)
for scanner.Scan() {
line := strings.TrimSpace(scanner.Text())
if line == "" || strings.HasPrefix(line, "#") {
continue
}
fields := strings.Fields(line)
if len(fields) < 2 {
continue
}
switch fields[0] {
case "machine":
if current != nil {
entries = append(entries, *current)
}
current = &entry{machine: fields[1]}
case "password":
if current != nil {
current.password = fields[1]
}
}
}
if current != nil {
entries = append(entries, *current)
}
if err := scanner.Err(); err != nil {
return nil, err
}
return entries, nil
}
func write(path string, entries []entry) error {
if err := os.MkdirAll(filepath.Dir(path), 0700); err != nil {
return fmt.Errorf("creating directory for %s: %w", path, err)
}
var b strings.Builder
for i, e := range entries {
if i > 0 {
b.WriteString("\n")
}
fmt.Fprintf(&b, "machine %s\npassword %s\n", e.machine, e.password)
}
return os.WriteFile(path, []byte(b.String()), 0600)
}

View file

@ -0,0 +1,175 @@
package netrc
import (
"os"
"path/filepath"
"testing"
)
func TestUpsertEmptyFile(t *testing.T) {
dir := t.TempDir()
path := filepath.Join(dir, "netrc")
if err := Upsert(path, "cache.example.com", "token123"); err != nil {
t.Fatalf("unexpected error: %v", err)
}
pw, err := GetPassword(path, "cache.example.com")
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if pw != "token123" {
t.Errorf("password = %q, want %q", pw, "token123")
}
}
func TestUpsertUpdateExisting(t *testing.T) {
dir := t.TempDir()
path := filepath.Join(dir, "netrc")
initial := "machine other.host\npassword otherpass\n\nmachine cache.example.com\npassword oldtoken\n"
if err := os.WriteFile(path, []byte(initial), 0600); err != nil {
t.Fatal(err)
}
if err := Upsert(path, "cache.example.com", "newtoken"); err != nil {
t.Fatalf("unexpected error: %v", err)
}
// Check updated entry
pw, err := GetPassword(path, "cache.example.com")
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if pw != "newtoken" {
t.Errorf("password = %q, want %q", pw, "newtoken")
}
// Check other entry preserved
pw, err = GetPassword(path, "other.host")
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if pw != "otherpass" {
t.Errorf("other password = %q, want %q", pw, "otherpass")
}
}
func TestUpsertAppend(t *testing.T) {
dir := t.TempDir()
path := filepath.Join(dir, "netrc")
initial := "machine existing.host\npassword existingpass\n"
if err := os.WriteFile(path, []byte(initial), 0600); err != nil {
t.Fatal(err)
}
if err := Upsert(path, "cache.example.com", "newtoken"); err != nil {
t.Fatalf("unexpected error: %v", err)
}
pw, err := GetPassword(path, "cache.example.com")
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if pw != "newtoken" {
t.Errorf("password = %q, want %q", pw, "newtoken")
}
pw, err = GetPassword(path, "existing.host")
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if pw != "existingpass" {
t.Errorf("existing password = %q, want %q", pw, "existingpass")
}
}
func TestRemove(t *testing.T) {
dir := t.TempDir()
path := filepath.Join(dir, "netrc")
initial := "machine keep.host\npassword keeppass\n\nmachine remove.host\npassword removepass\n"
if err := os.WriteFile(path, []byte(initial), 0600); err != nil {
t.Fatal(err)
}
if err := Remove(path, "remove.host"); err != nil {
t.Fatalf("unexpected error: %v", err)
}
pw, err := GetPassword(path, "remove.host")
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if pw != "" {
t.Errorf("removed entry still has password = %q", pw)
}
pw, err = GetPassword(path, "keep.host")
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if pw != "keeppass" {
t.Errorf("kept password = %q, want %q", pw, "keeppass")
}
}
func TestRemoveNonexistentFile(t *testing.T) {
dir := t.TempDir()
path := filepath.Join(dir, "nonexistent")
if err := Remove(path, "anything"); err != nil {
t.Fatalf("unexpected error: %v", err)
}
}
func TestGetPasswordNoFile(t *testing.T) {
dir := t.TempDir()
path := filepath.Join(dir, "nonexistent")
pw, err := GetPassword(path, "anything")
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if pw != "" {
t.Errorf("password = %q, want empty", pw)
}
}
func TestGetPasswordNotFound(t *testing.T) {
dir := t.TempDir()
path := filepath.Join(dir, "netrc")
content := "machine other.host\npassword otherpass\n"
if err := os.WriteFile(path, []byte(content), 0600); err != nil {
t.Fatal(err)
}
pw, err := GetPassword(path, "missing.host")
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if pw != "" {
t.Errorf("password = %q, want empty", pw)
}
}
func TestFilePermissions(t *testing.T) {
dir := t.TempDir()
path := filepath.Join(dir, "netrc")
if err := Upsert(path, "cache.example.com", "token"); err != nil {
t.Fatalf("unexpected error: %v", err)
}
info, err := os.Stat(path)
if err != nil {
t.Fatalf("stat error: %v", err)
}
perm := info.Mode().Perm()
if perm != 0600 {
t.Errorf("file permissions = %o, want 0600", perm)
}
}

25
internal/pkce/pkce.go Normal file
View file

@ -0,0 +1,25 @@
package pkce
import (
"crypto/rand"
"crypto/sha256"
"encoding/base64"
)
const verifierLength = 43
// Generate creates a PKCE code verifier and its S256 challenge.
func Generate() (verifier, challenge string, err error) {
// Generate random bytes and encode to URL-safe base64 (no padding)
buf := make([]byte, 32)
if _, err := rand.Read(buf); err != nil {
return "", "", err
}
verifier = base64.RawURLEncoding.EncodeToString(buf)
// Derive challenge: base64url(sha256(verifier))
h := sha256.Sum256([]byte(verifier))
challenge = base64.RawURLEncoding.EncodeToString(h[:])
return verifier, challenge, nil
}

View file

@ -0,0 +1,63 @@
package pkce
import (
"crypto/sha256"
"encoding/base64"
"regexp"
"testing"
)
func TestGenerateVerifierLength(t *testing.T) {
verifier, _, err := Generate()
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if len(verifier) < 43 || len(verifier) > 128 {
t.Errorf("verifier length = %d, want 43-128", len(verifier))
}
}
func TestGenerateVerifierCharacterSet(t *testing.T) {
verifier, _, err := Generate()
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
// RFC 7636: unreserved characters [A-Z] / [a-z] / [0-9] / "-" / "." / "_" / "~"
valid := regexp.MustCompile(`^[A-Za-z0-9\-._~]+$`)
if !valid.MatchString(verifier) {
t.Errorf("verifier contains invalid characters: %q", verifier)
}
}
func TestGenerateChallengeCorrectness(t *testing.T) {
verifier, challenge, err := Generate()
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
// Recompute challenge from verifier
h := sha256.Sum256([]byte(verifier))
expected := base64.RawURLEncoding.EncodeToString(h[:])
if challenge != expected {
t.Errorf("challenge = %q, want %q", challenge, expected)
}
}
func TestGenerateUniqueness(t *testing.T) {
v1, _, err := Generate()
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
v2, _, err := Generate()
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if v1 == v2 {
t.Error("two Generate() calls produced identical verifiers")
}
}

46
internal/token/jwt.go Normal file
View file

@ -0,0 +1,46 @@
package token
import (
"encoding/base64"
"encoding/json"
"fmt"
"strings"
"time"
)
// DecodePayload decodes the payload (claims) of a JWT without verifying the signature.
func DecodePayload(tokenStr string) (map[string]interface{}, error) {
parts := strings.Split(tokenStr, ".")
if len(parts) != 3 {
return nil, fmt.Errorf("invalid JWT: expected 3 parts, got %d", len(parts))
}
payload, err := base64.RawURLEncoding.DecodeString(parts[1])
if err != nil {
return nil, fmt.Errorf("decoding JWT payload: %w", err)
}
var claims map[string]interface{}
if err := json.Unmarshal(payload, &claims); err != nil {
return nil, fmt.Errorf("unmarshaling JWT payload: %w", err)
}
return claims, nil
}
// ExpiryInfo extracts the expiry time and remaining duration from JWT claims.
func ExpiryInfo(claims map[string]interface{}) (exp time.Time, remaining time.Duration) {
expVal, ok := claims["exp"]
if !ok {
return time.Time{}, 0
}
expFloat, ok := expVal.(float64)
if !ok {
return time.Time{}, 0
}
exp = time.Unix(int64(expFloat), 0)
remaining = time.Until(exp)
return exp, remaining
}

View file

@ -0,0 +1,92 @@
package token
import (
"encoding/base64"
"encoding/json"
"testing"
"time"
)
// buildTestJWT builds a JWT string with the given payload (no real signature).
func buildTestJWT(claims map[string]interface{}) string {
header := base64.RawURLEncoding.EncodeToString([]byte(`{"alg":"RS256","typ":"JWT"}`))
payload, _ := json.Marshal(claims)
payloadEnc := base64.RawURLEncoding.EncodeToString(payload)
sig := base64.RawURLEncoding.EncodeToString([]byte("fakesig"))
return header + "." + payloadEnc + "." + sig
}
func TestDecodePayload(t *testing.T) {
claims := map[string]interface{}{
"iss": "https://id.example.com/realms/test",
"sub": "user123",
"exp": float64(1700000000),
"aud": "nix-cache",
}
jwt := buildTestJWT(claims)
decoded, err := DecodePayload(jwt)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if decoded["iss"] != "https://id.example.com/realms/test" {
t.Errorf("iss = %v, want %v", decoded["iss"], "https://id.example.com/realms/test")
}
if decoded["sub"] != "user123" {
t.Errorf("sub = %v, want %v", decoded["sub"], "user123")
}
if decoded["aud"] != "nix-cache" {
t.Errorf("aud = %v, want %v", decoded["aud"], "nix-cache")
}
}
func TestExpiryInfo(t *testing.T) {
futureExp := time.Now().Add(1 * time.Hour).Unix()
claims := map[string]interface{}{
"exp": float64(futureExp),
}
exp, remaining := ExpiryInfo(claims)
if exp.Unix() != futureExp {
t.Errorf("exp = %v, want %v", exp.Unix(), futureExp)
}
if remaining < 59*time.Minute || remaining > 61*time.Minute {
t.Errorf("remaining = %v, expected ~1 hour", remaining)
}
}
func TestExpiryInfoPast(t *testing.T) {
pastExp := time.Now().Add(-1 * time.Hour).Unix()
claims := map[string]interface{}{
"exp": float64(pastExp),
}
_, remaining := ExpiryInfo(claims)
if remaining >= 0 {
t.Errorf("remaining = %v, expected negative (expired)", remaining)
}
}
func TestDecodePayloadMalformed(t *testing.T) {
tests := []struct {
name string
token string
}{
{"no dots", "nodots"},
{"one dot", "one.dot"},
{"empty payload", "header..sig"},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
_, err := DecodePayload(tt.token)
if err == nil {
t.Error("expected error, got nil")
}
})
}
}