initial working version
This commit is contained in:
parent
db6b90134d
commit
d986a0b31a
19 changed files with 1430 additions and 0 deletions
27
internal/browser/browser.go
Normal file
27
internal/browser/browser.go
Normal file
|
|
@ -0,0 +1,27 @@
|
|||
package browser
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os/exec"
|
||||
"runtime"
|
||||
)
|
||||
|
||||
// Open opens the given URL in the user's default browser.
|
||||
func Open(url string) error {
|
||||
var cmd *exec.Cmd
|
||||
|
||||
switch runtime.GOOS {
|
||||
case "linux":
|
||||
cmd = exec.Command("xdg-open", url)
|
||||
case "darwin":
|
||||
cmd = exec.Command("open", url)
|
||||
default:
|
||||
return fmt.Errorf("unsupported platform %s; open this URL manually:\n %s", runtime.GOOS, url)
|
||||
}
|
||||
|
||||
if err := cmd.Start(); err != nil {
|
||||
return fmt.Errorf("failed to open browser: %w\n Open this URL manually:\n %s", err, url)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
64
internal/config/config.go
Normal file
64
internal/config/config.go
Normal file
|
|
@ -0,0 +1,64 @@
|
|||
package config
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
|
||||
"github.com/adrg/xdg"
|
||||
toml "github.com/pelletier/go-toml/v2"
|
||||
)
|
||||
|
||||
type Config struct {
|
||||
Issuer string `toml:"issuer"`
|
||||
ClientID string `toml:"client_id"`
|
||||
ClientSecret string `toml:"client_secret,omitempty"`
|
||||
CacheHost string `toml:"cache_host"`
|
||||
NetrcPath string `toml:"netrc_path"`
|
||||
}
|
||||
|
||||
// Load reads the config from the given path, or from the default XDG location.
|
||||
func Load(path string) (*Config, error) {
|
||||
if path == "" {
|
||||
path = filepath.Join(xdg.ConfigHome, "nix-cache-login", "config.toml")
|
||||
}
|
||||
|
||||
data, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("reading config file: %w", err)
|
||||
}
|
||||
|
||||
var cfg Config
|
||||
if err := toml.Unmarshal(data, &cfg); err != nil {
|
||||
return nil, fmt.Errorf("parsing config file: %w", err)
|
||||
}
|
||||
|
||||
cfg.NetrcPath = os.ExpandEnv(cfg.NetrcPath)
|
||||
|
||||
if err := cfg.validate(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &cfg, nil
|
||||
}
|
||||
|
||||
func (c *Config) validate() error {
|
||||
if c.Issuer == "" {
|
||||
return fmt.Errorf("config: issuer is required")
|
||||
}
|
||||
if c.ClientID == "" {
|
||||
return fmt.Errorf("config: client_id is required")
|
||||
}
|
||||
if c.CacheHost == "" {
|
||||
return fmt.Errorf("config: cache_host is required")
|
||||
}
|
||||
if c.NetrcPath == "" {
|
||||
return fmt.Errorf("config: netrc_path is required")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// RefreshTokenPath returns the path to the stored refresh token.
|
||||
func RefreshTokenPath() string {
|
||||
return filepath.Join(xdg.ConfigHome, "nix-cache-login", "refresh-token")
|
||||
}
|
||||
177
internal/config/config_test.go
Normal file
177
internal/config/config_test.go
Normal file
|
|
@ -0,0 +1,177 @@
|
|||
package config
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestLoadValidConfig(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
cfgFile := filepath.Join(dir, "config.toml")
|
||||
|
||||
content := `
|
||||
issuer = "https://id.example.com/realms/test"
|
||||
client_id = "nix-cache"
|
||||
cache_host = "cache.example.com"
|
||||
netrc_path = "/home/user/.config/nix/netrc"
|
||||
`
|
||||
if err := os.WriteFile(cfgFile, []byte(content), 0644); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
cfg, err := Load(cfgFile)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
if cfg.Issuer != "https://id.example.com/realms/test" {
|
||||
t.Errorf("issuer = %q, want %q", cfg.Issuer, "https://id.example.com/realms/test")
|
||||
}
|
||||
if cfg.ClientID != "nix-cache" {
|
||||
t.Errorf("client_id = %q, want %q", cfg.ClientID, "nix-cache")
|
||||
}
|
||||
if cfg.CacheHost != "cache.example.com" {
|
||||
t.Errorf("cache_host = %q, want %q", cfg.CacheHost, "cache.example.com")
|
||||
}
|
||||
if cfg.ClientSecret != "" {
|
||||
t.Errorf("client_secret = %q, want empty", cfg.ClientSecret)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoadConfigWithClientSecret(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
cfgFile := filepath.Join(dir, "config.toml")
|
||||
|
||||
content := `
|
||||
issuer = "https://id.example.com/realms/test"
|
||||
client_id = "nix-cache-server"
|
||||
client_secret = "super-secret"
|
||||
cache_host = "cache.example.com"
|
||||
netrc_path = "/tmp/netrc"
|
||||
`
|
||||
if err := os.WriteFile(cfgFile, []byte(content), 0644); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
cfg, err := Load(cfgFile)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
if cfg.ClientSecret != "super-secret" {
|
||||
t.Errorf("client_secret = %q, want %q", cfg.ClientSecret, "super-secret")
|
||||
}
|
||||
}
|
||||
|
||||
func TestEnvVarExpansionInNetrcPath(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
cfgFile := filepath.Join(dir, "config.toml")
|
||||
|
||||
t.Setenv("TEST_CONFIG_DIR", "/custom/config")
|
||||
|
||||
content := `
|
||||
issuer = "https://id.example.com/realms/test"
|
||||
client_id = "nix-cache"
|
||||
cache_host = "cache.example.com"
|
||||
netrc_path = "$TEST_CONFIG_DIR/nix/netrc"
|
||||
`
|
||||
if err := os.WriteFile(cfgFile, []byte(content), 0644); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
cfg, err := Load(cfgFile)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
if cfg.NetrcPath != "/custom/config/nix/netrc" {
|
||||
t.Errorf("netrc_path = %q, want %q", cfg.NetrcPath, "/custom/config/nix/netrc")
|
||||
}
|
||||
}
|
||||
|
||||
func TestEnvVarExpansionBraces(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
cfgFile := filepath.Join(dir, "config.toml")
|
||||
|
||||
t.Setenv("MY_HOME", "/home/testuser")
|
||||
|
||||
content := `
|
||||
issuer = "https://id.example.com/realms/test"
|
||||
client_id = "nix-cache"
|
||||
cache_host = "cache.example.com"
|
||||
netrc_path = "${MY_HOME}/.config/nix/netrc"
|
||||
`
|
||||
if err := os.WriteFile(cfgFile, []byte(content), 0644); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
cfg, err := Load(cfgFile)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
if cfg.NetrcPath != "/home/testuser/.config/nix/netrc" {
|
||||
t.Errorf("netrc_path = %q, want %q", cfg.NetrcPath, "/home/testuser/.config/nix/netrc")
|
||||
}
|
||||
}
|
||||
|
||||
func TestMissingRequiredFields(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
content string
|
||||
errMsg string
|
||||
}{
|
||||
{
|
||||
name: "missing issuer",
|
||||
content: `client_id = "x"` + "\n" + `cache_host = "x"` + "\n" + `netrc_path = "/tmp/x"`,
|
||||
errMsg: "issuer is required",
|
||||
},
|
||||
{
|
||||
name: "missing client_id",
|
||||
content: `issuer = "https://x"` + "\n" + `cache_host = "x"` + "\n" + `netrc_path = "/tmp/x"`,
|
||||
errMsg: "client_id is required",
|
||||
},
|
||||
{
|
||||
name: "missing cache_host",
|
||||
content: `issuer = "https://x"` + "\n" + `client_id = "x"` + "\n" + `netrc_path = "/tmp/x"`,
|
||||
errMsg: "cache_host is required",
|
||||
},
|
||||
{
|
||||
name: "missing netrc_path",
|
||||
content: `issuer = "https://x"` + "\n" + `client_id = "x"` + "\n" + `cache_host = "x"`,
|
||||
errMsg: "netrc_path is required",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
cfgFile := filepath.Join(dir, "config.toml")
|
||||
if err := os.WriteFile(cfgFile, []byte(tt.content), 0644); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
_, err := Load(cfgFile)
|
||||
if err == nil {
|
||||
t.Fatal("expected error, got nil")
|
||||
}
|
||||
if !contains(err.Error(), tt.errMsg) {
|
||||
t.Errorf("error = %q, want to contain %q", err.Error(), tt.errMsg)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func contains(s, substr string) bool {
|
||||
return len(s) >= len(substr) && searchString(s, substr)
|
||||
}
|
||||
|
||||
func searchString(s, substr string) bool {
|
||||
for i := 0; i <= len(s)-len(substr); i++ {
|
||||
if s[i:i+len(substr)] == substr {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
136
internal/netrc/netrc.go
Normal file
136
internal/netrc/netrc.go
Normal file
|
|
@ -0,0 +1,136 @@
|
|||
package netrc
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// Upsert updates or inserts a machine entry in the netrc file.
|
||||
// Only the password field is written (Nix uses password from netrc as auth).
|
||||
func Upsert(path, machine, password string) error {
|
||||
entries, err := parse(path)
|
||||
if err != nil && !os.IsNotExist(err) {
|
||||
return err
|
||||
}
|
||||
|
||||
found := false
|
||||
for i, e := range entries {
|
||||
if e.machine == machine {
|
||||
entries[i].password = password
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
entries = append(entries, entry{machine: machine, password: password})
|
||||
}
|
||||
|
||||
return write(path, entries)
|
||||
}
|
||||
|
||||
// Remove removes the entry for the given machine from the netrc file.
|
||||
func Remove(path, machine string) error {
|
||||
entries, err := parse(path)
|
||||
if err != nil {
|
||||
if os.IsNotExist(err) {
|
||||
return nil
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
var filtered []entry
|
||||
for _, e := range entries {
|
||||
if e.machine != machine {
|
||||
filtered = append(filtered, e)
|
||||
}
|
||||
}
|
||||
|
||||
return write(path, filtered)
|
||||
}
|
||||
|
||||
// GetPassword returns the password for the given machine, or empty string if not found.
|
||||
func GetPassword(path, machine string) (string, error) {
|
||||
entries, err := parse(path)
|
||||
if err != nil {
|
||||
if os.IsNotExist(err) {
|
||||
return "", nil
|
||||
}
|
||||
return "", err
|
||||
}
|
||||
|
||||
for _, e := range entries {
|
||||
if e.machine == machine {
|
||||
return e.password, nil
|
||||
}
|
||||
}
|
||||
return "", nil
|
||||
}
|
||||
|
||||
type entry struct {
|
||||
machine string
|
||||
password string
|
||||
}
|
||||
|
||||
func parse(path string) ([]entry, error) {
|
||||
f, err := os.Open(path)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer f.Close()
|
||||
|
||||
var entries []entry
|
||||
var current *entry
|
||||
|
||||
scanner := bufio.NewScanner(f)
|
||||
for scanner.Scan() {
|
||||
line := strings.TrimSpace(scanner.Text())
|
||||
if line == "" || strings.HasPrefix(line, "#") {
|
||||
continue
|
||||
}
|
||||
|
||||
fields := strings.Fields(line)
|
||||
if len(fields) < 2 {
|
||||
continue
|
||||
}
|
||||
|
||||
switch fields[0] {
|
||||
case "machine":
|
||||
if current != nil {
|
||||
entries = append(entries, *current)
|
||||
}
|
||||
current = &entry{machine: fields[1]}
|
||||
case "password":
|
||||
if current != nil {
|
||||
current.password = fields[1]
|
||||
}
|
||||
}
|
||||
}
|
||||
if current != nil {
|
||||
entries = append(entries, *current)
|
||||
}
|
||||
|
||||
if err := scanner.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return entries, nil
|
||||
}
|
||||
|
||||
func write(path string, entries []entry) error {
|
||||
if err := os.MkdirAll(filepath.Dir(path), 0700); err != nil {
|
||||
return fmt.Errorf("creating directory for %s: %w", path, err)
|
||||
}
|
||||
|
||||
var b strings.Builder
|
||||
for i, e := range entries {
|
||||
if i > 0 {
|
||||
b.WriteString("\n")
|
||||
}
|
||||
fmt.Fprintf(&b, "machine %s\npassword %s\n", e.machine, e.password)
|
||||
}
|
||||
|
||||
return os.WriteFile(path, []byte(b.String()), 0600)
|
||||
}
|
||||
175
internal/netrc/netrc_test.go
Normal file
175
internal/netrc/netrc_test.go
Normal file
|
|
@ -0,0 +1,175 @@
|
|||
package netrc
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestUpsertEmptyFile(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
path := filepath.Join(dir, "netrc")
|
||||
|
||||
if err := Upsert(path, "cache.example.com", "token123"); err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
pw, err := GetPassword(path, "cache.example.com")
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if pw != "token123" {
|
||||
t.Errorf("password = %q, want %q", pw, "token123")
|
||||
}
|
||||
}
|
||||
|
||||
func TestUpsertUpdateExisting(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
path := filepath.Join(dir, "netrc")
|
||||
|
||||
initial := "machine other.host\npassword otherpass\n\nmachine cache.example.com\npassword oldtoken\n"
|
||||
if err := os.WriteFile(path, []byte(initial), 0600); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if err := Upsert(path, "cache.example.com", "newtoken"); err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
// Check updated entry
|
||||
pw, err := GetPassword(path, "cache.example.com")
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if pw != "newtoken" {
|
||||
t.Errorf("password = %q, want %q", pw, "newtoken")
|
||||
}
|
||||
|
||||
// Check other entry preserved
|
||||
pw, err = GetPassword(path, "other.host")
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if pw != "otherpass" {
|
||||
t.Errorf("other password = %q, want %q", pw, "otherpass")
|
||||
}
|
||||
}
|
||||
|
||||
func TestUpsertAppend(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
path := filepath.Join(dir, "netrc")
|
||||
|
||||
initial := "machine existing.host\npassword existingpass\n"
|
||||
if err := os.WriteFile(path, []byte(initial), 0600); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if err := Upsert(path, "cache.example.com", "newtoken"); err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
pw, err := GetPassword(path, "cache.example.com")
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if pw != "newtoken" {
|
||||
t.Errorf("password = %q, want %q", pw, "newtoken")
|
||||
}
|
||||
|
||||
pw, err = GetPassword(path, "existing.host")
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if pw != "existingpass" {
|
||||
t.Errorf("existing password = %q, want %q", pw, "existingpass")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRemove(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
path := filepath.Join(dir, "netrc")
|
||||
|
||||
initial := "machine keep.host\npassword keeppass\n\nmachine remove.host\npassword removepass\n"
|
||||
if err := os.WriteFile(path, []byte(initial), 0600); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if err := Remove(path, "remove.host"); err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
pw, err := GetPassword(path, "remove.host")
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if pw != "" {
|
||||
t.Errorf("removed entry still has password = %q", pw)
|
||||
}
|
||||
|
||||
pw, err = GetPassword(path, "keep.host")
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if pw != "keeppass" {
|
||||
t.Errorf("kept password = %q, want %q", pw, "keeppass")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRemoveNonexistentFile(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
path := filepath.Join(dir, "nonexistent")
|
||||
|
||||
if err := Remove(path, "anything"); err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetPasswordNoFile(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
path := filepath.Join(dir, "nonexistent")
|
||||
|
||||
pw, err := GetPassword(path, "anything")
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if pw != "" {
|
||||
t.Errorf("password = %q, want empty", pw)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetPasswordNotFound(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
path := filepath.Join(dir, "netrc")
|
||||
|
||||
content := "machine other.host\npassword otherpass\n"
|
||||
if err := os.WriteFile(path, []byte(content), 0600); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
pw, err := GetPassword(path, "missing.host")
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if pw != "" {
|
||||
t.Errorf("password = %q, want empty", pw)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFilePermissions(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
path := filepath.Join(dir, "netrc")
|
||||
|
||||
if err := Upsert(path, "cache.example.com", "token"); err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
info, err := os.Stat(path)
|
||||
if err != nil {
|
||||
t.Fatalf("stat error: %v", err)
|
||||
}
|
||||
|
||||
perm := info.Mode().Perm()
|
||||
if perm != 0600 {
|
||||
t.Errorf("file permissions = %o, want 0600", perm)
|
||||
}
|
||||
}
|
||||
25
internal/pkce/pkce.go
Normal file
25
internal/pkce/pkce.go
Normal file
|
|
@ -0,0 +1,25 @@
|
|||
package pkce
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"crypto/sha256"
|
||||
"encoding/base64"
|
||||
)
|
||||
|
||||
const verifierLength = 43
|
||||
|
||||
// Generate creates a PKCE code verifier and its S256 challenge.
|
||||
func Generate() (verifier, challenge string, err error) {
|
||||
// Generate random bytes and encode to URL-safe base64 (no padding)
|
||||
buf := make([]byte, 32)
|
||||
if _, err := rand.Read(buf); err != nil {
|
||||
return "", "", err
|
||||
}
|
||||
verifier = base64.RawURLEncoding.EncodeToString(buf)
|
||||
|
||||
// Derive challenge: base64url(sha256(verifier))
|
||||
h := sha256.Sum256([]byte(verifier))
|
||||
challenge = base64.RawURLEncoding.EncodeToString(h[:])
|
||||
|
||||
return verifier, challenge, nil
|
||||
}
|
||||
63
internal/pkce/pkce_test.go
Normal file
63
internal/pkce/pkce_test.go
Normal file
|
|
@ -0,0 +1,63 @@
|
|||
package pkce
|
||||
|
||||
import (
|
||||
"crypto/sha256"
|
||||
"encoding/base64"
|
||||
"regexp"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestGenerateVerifierLength(t *testing.T) {
|
||||
verifier, _, err := Generate()
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
if len(verifier) < 43 || len(verifier) > 128 {
|
||||
t.Errorf("verifier length = %d, want 43-128", len(verifier))
|
||||
}
|
||||
}
|
||||
|
||||
func TestGenerateVerifierCharacterSet(t *testing.T) {
|
||||
verifier, _, err := Generate()
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
// RFC 7636: unreserved characters [A-Z] / [a-z] / [0-9] / "-" / "." / "_" / "~"
|
||||
valid := regexp.MustCompile(`^[A-Za-z0-9\-._~]+$`)
|
||||
if !valid.MatchString(verifier) {
|
||||
t.Errorf("verifier contains invalid characters: %q", verifier)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGenerateChallengeCorrectness(t *testing.T) {
|
||||
verifier, challenge, err := Generate()
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
// Recompute challenge from verifier
|
||||
h := sha256.Sum256([]byte(verifier))
|
||||
expected := base64.RawURLEncoding.EncodeToString(h[:])
|
||||
|
||||
if challenge != expected {
|
||||
t.Errorf("challenge = %q, want %q", challenge, expected)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGenerateUniqueness(t *testing.T) {
|
||||
v1, _, err := Generate()
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
v2, _, err := Generate()
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
if v1 == v2 {
|
||||
t.Error("two Generate() calls produced identical verifiers")
|
||||
}
|
||||
}
|
||||
46
internal/token/jwt.go
Normal file
46
internal/token/jwt.go
Normal file
|
|
@ -0,0 +1,46 @@
|
|||
package token
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
// DecodePayload decodes the payload (claims) of a JWT without verifying the signature.
|
||||
func DecodePayload(tokenStr string) (map[string]interface{}, error) {
|
||||
parts := strings.Split(tokenStr, ".")
|
||||
if len(parts) != 3 {
|
||||
return nil, fmt.Errorf("invalid JWT: expected 3 parts, got %d", len(parts))
|
||||
}
|
||||
|
||||
payload, err := base64.RawURLEncoding.DecodeString(parts[1])
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("decoding JWT payload: %w", err)
|
||||
}
|
||||
|
||||
var claims map[string]interface{}
|
||||
if err := json.Unmarshal(payload, &claims); err != nil {
|
||||
return nil, fmt.Errorf("unmarshaling JWT payload: %w", err)
|
||||
}
|
||||
|
||||
return claims, nil
|
||||
}
|
||||
|
||||
// ExpiryInfo extracts the expiry time and remaining duration from JWT claims.
|
||||
func ExpiryInfo(claims map[string]interface{}) (exp time.Time, remaining time.Duration) {
|
||||
expVal, ok := claims["exp"]
|
||||
if !ok {
|
||||
return time.Time{}, 0
|
||||
}
|
||||
|
||||
expFloat, ok := expVal.(float64)
|
||||
if !ok {
|
||||
return time.Time{}, 0
|
||||
}
|
||||
|
||||
exp = time.Unix(int64(expFloat), 0)
|
||||
remaining = time.Until(exp)
|
||||
return exp, remaining
|
||||
}
|
||||
92
internal/token/jwt_test.go
Normal file
92
internal/token/jwt_test.go
Normal file
|
|
@ -0,0 +1,92 @@
|
|||
package token
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
// buildTestJWT builds a JWT string with the given payload (no real signature).
|
||||
func buildTestJWT(claims map[string]interface{}) string {
|
||||
header := base64.RawURLEncoding.EncodeToString([]byte(`{"alg":"RS256","typ":"JWT"}`))
|
||||
payload, _ := json.Marshal(claims)
|
||||
payloadEnc := base64.RawURLEncoding.EncodeToString(payload)
|
||||
sig := base64.RawURLEncoding.EncodeToString([]byte("fakesig"))
|
||||
return header + "." + payloadEnc + "." + sig
|
||||
}
|
||||
|
||||
func TestDecodePayload(t *testing.T) {
|
||||
claims := map[string]interface{}{
|
||||
"iss": "https://id.example.com/realms/test",
|
||||
"sub": "user123",
|
||||
"exp": float64(1700000000),
|
||||
"aud": "nix-cache",
|
||||
}
|
||||
|
||||
jwt := buildTestJWT(claims)
|
||||
|
||||
decoded, err := DecodePayload(jwt)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
if decoded["iss"] != "https://id.example.com/realms/test" {
|
||||
t.Errorf("iss = %v, want %v", decoded["iss"], "https://id.example.com/realms/test")
|
||||
}
|
||||
if decoded["sub"] != "user123" {
|
||||
t.Errorf("sub = %v, want %v", decoded["sub"], "user123")
|
||||
}
|
||||
if decoded["aud"] != "nix-cache" {
|
||||
t.Errorf("aud = %v, want %v", decoded["aud"], "nix-cache")
|
||||
}
|
||||
}
|
||||
|
||||
func TestExpiryInfo(t *testing.T) {
|
||||
futureExp := time.Now().Add(1 * time.Hour).Unix()
|
||||
claims := map[string]interface{}{
|
||||
"exp": float64(futureExp),
|
||||
}
|
||||
|
||||
exp, remaining := ExpiryInfo(claims)
|
||||
|
||||
if exp.Unix() != futureExp {
|
||||
t.Errorf("exp = %v, want %v", exp.Unix(), futureExp)
|
||||
}
|
||||
if remaining < 59*time.Minute || remaining > 61*time.Minute {
|
||||
t.Errorf("remaining = %v, expected ~1 hour", remaining)
|
||||
}
|
||||
}
|
||||
|
||||
func TestExpiryInfoPast(t *testing.T) {
|
||||
pastExp := time.Now().Add(-1 * time.Hour).Unix()
|
||||
claims := map[string]interface{}{
|
||||
"exp": float64(pastExp),
|
||||
}
|
||||
|
||||
_, remaining := ExpiryInfo(claims)
|
||||
|
||||
if remaining >= 0 {
|
||||
t.Errorf("remaining = %v, expected negative (expired)", remaining)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDecodePayloadMalformed(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
token string
|
||||
}{
|
||||
{"no dots", "nodots"},
|
||||
{"one dot", "one.dot"},
|
||||
{"empty payload", "header..sig"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
_, err := DecodePayload(tt.token)
|
||||
if err == nil {
|
||||
t.Error("expected error, got nil")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
Loading…
Add table
Add a link
Reference in a new issue