package config

import (
	"os"
	"path/filepath"
	"testing"
	"time"

	"github.com/spf13/viper"
)

// TestNewViperConfig checks that NewViperConfig returns a non-nil wrapper and
// that values set on the supplied viper instance are readable through it.
func TestNewViperConfig(t *testing.T) {
	t.Parallel()

	v := viper.New()
	v.Set("test.key", "test.value")

	cfg := NewViperConfig(v)
	if cfg == nil {
		t.Fatal("NewViperConfig returned nil")
	}

	// Reading the key back proves the wrapper actually reads from v.
	if got := cfg.GetString("test.key"); got != "test.value" {
		t.Errorf("GetString(%q) = %q, want %q", "test.key", got, "test.value")
	}
}

// TestViperConfig_Get checks that Get returns stored values unchanged and nil
// for keys that were never set.
func TestViperConfig_Get(t *testing.T) {
	t.Parallel()

	// A nil setValue means the key is left unset.
	tests := []struct {
		name     string
		key      string
		setValue any
		want     any
	}{
		{
			name:     "string value",
			key:      "test.string",
			setValue: "test",
			want:     "test",
		},
		{
			name:     "int value",
			key:      "test.int",
			setValue: 42,
			want:     42,
		},
		{
			name:     "bool value",
			key:      "test.bool",
			setValue: true,
			want:     true,
		},
		{
			name:     "non-existent key",
			key:      "test.missing",
			setValue: nil,
			want:     nil,
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			t.Parallel()

			v := viper.New()
			if tt.setValue != nil {
				v.Set(tt.key, tt.setValue)
			}

			cfg := NewViperConfig(v)
			if got := cfg.Get(tt.key); got != tt.want {
				t.Errorf("Get(%q) = %v, want %v", tt.key, got, tt.want)
			}
		})
	}
}

// TestViperConfig_GetString checks GetString for set, empty and missing keys.
func TestViperConfig_GetString(t *testing.T) {
	t.Parallel()

	// skipSet leaves the key unset so the missing-key path is really exercised.
	tests := []struct {
		name     string
		key      string
		setValue string
		skipSet  bool
		want     string
	}{
		{
			name:     "valid string",
			key:      "test.string",
			setValue: "hello",
			want:     "hello",
		},
		{
			name:     "empty string",
			key:      "test.empty",
			setValue: "",
			want:     "",
		},
		{
			name:    "non-existent key",
			key:     "test.missing",
			skipSet: true,
			want:    "",
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			t.Parallel()

			v := viper.New()
			if !tt.skipSet {
				v.Set(tt.key, tt.setValue)
			}

			cfg := NewViperConfig(v)
			if got := cfg.GetString(tt.key); got != tt.want {
				t.Errorf("GetString(%q) = %q, want %q", tt.key, got, tt.want)
			}
		})
	}
}

// TestViperConfig_GetInt checks GetInt for set, zero and missing keys.
func TestViperConfig_GetInt(t *testing.T) {
	t.Parallel()

	// skipSet leaves the key unset so the missing-key path is really exercised.
	tests := []struct {
		name     string
		key      string
		setValue int
		skipSet  bool
		want     int
	}{
		{
			name:     "valid int",
			key:      "test.int",
			setValue: 42,
			want:     42,
		},
		{
			name:     "zero value",
			key:      "test.zero",
			setValue: 0,
			want:     0,
		},
		{
			name:    "non-existent key",
			key:     "test.missing",
			skipSet: true,
			want:    0,
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			t.Parallel()

			v := viper.New()
			if !tt.skipSet {
				v.Set(tt.key, tt.setValue)
			}

			cfg := NewViperConfig(v)
			if got := cfg.GetInt(tt.key); got != tt.want {
				t.Errorf("GetInt(%q) = %d, want %d", tt.key, got, tt.want)
			}
		})
	}
}

// TestViperConfig_GetBool checks GetBool for true, false and missing keys.
func TestViperConfig_GetBool(t *testing.T) {
	t.Parallel()

	// skipSet leaves the key unset so the missing-key path is really exercised.
	tests := []struct {
		name     string
		key      string
		setValue bool
		skipSet  bool
		want     bool
	}{
		{
			name:     "true value",
			key:      "test.bool",
			setValue: true,
			want:     true,
		},
		{
			name:     "false value",
			key:      "test.bool",
			setValue: false,
			want:     false,
		},
		{
			name:    "non-existent key",
			key:     "test.missing",
			skipSet: true,
			want:    false,
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			t.Parallel()

			v := viper.New()
			if !tt.skipSet {
				v.Set(tt.key, tt.setValue)
			}

			cfg := NewViperConfig(v)
			if got := cfg.GetBool(tt.key); got != tt.want {
				t.Errorf("GetBool(%q) = %v, want %v", tt.key, got, tt.want)
			}
		})
	}
}

// TestViperConfig_GetStringSlice checks GetStringSlice for populated, empty and
// missing keys. Only length and element values are compared, so a nil and an
// empty result are treated as equal.
func TestViperConfig_GetStringSlice(t *testing.T) {
	t.Parallel()

	// skipSet leaves the key unset so the missing-key path is really exercised.
	tests := []struct {
		name     string
		key      string
		setValue []string
		skipSet  bool
		want     []string
	}{
		{
			name:     "valid slice",
			key:      "test.slice",
			setValue: []string{"a", "b", "c"},
			want:     []string{"a", "b", "c"},
		},
		{
			name:     "empty slice",
			key:      "test.empty",
			setValue: []string{},
			want:     []string{},
		},
		{
			name:    "non-existent key",
			key:     "test.missing",
			skipSet: true,
			want:    nil,
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			t.Parallel()

			v := viper.New()
			if !tt.skipSet {
				v.Set(tt.key, tt.setValue)
			}

			cfg := NewViperConfig(v)
			got := cfg.GetStringSlice(tt.key)
			if len(got) != len(tt.want) {
				t.Errorf("GetStringSlice(%q) length = %d, want %d", tt.key, len(got), len(tt.want))
				return
			}
			for i := range got {
				if got[i] != tt.want[i] {
					t.Errorf("GetStringSlice(%q)[%d] = %q, want %q", tt.key, i, got[i], tt.want[i])
				}
			}
		})
	}
}

// TestViperConfig_GetDuration checks GetDuration for set, zero and missing keys.
func TestViperConfig_GetDuration(t *testing.T) {
	t.Parallel()

	// skipSet leaves the key unset so the missing-key path is really exercised.
	tests := []struct {
		name     string
		key      string
		setValue time.Duration
		skipSet  bool
		want     time.Duration
	}{
		{
			name:     "valid duration",
			key:      "test.duration",
			setValue: 30 * time.Second,
			want:     30 * time.Second,
		},
		{
			name:     "zero duration",
			key:      "test.zero",
			setValue: 0,
			want:     0,
		},
		{
			name:    "non-existent key",
			key:     "test.missing",
			skipSet: true,
			want:    0,
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			t.Parallel()

			v := viper.New()
			if !tt.skipSet {
				v.Set(tt.key, tt.setValue)
			}

			cfg := NewViperConfig(v)
			if got := cfg.GetDuration(tt.key); got != tt.want {
				t.Errorf("GetDuration(%q) = %v, want %v", tt.key, got, tt.want)
			}
		})
	}
}

// TestViperConfig_IsSet checks that IsSet reports true only for keys that were
// explicitly set.
func TestViperConfig_IsSet(t *testing.T) {
	t.Parallel()

	// A nil setValue means the key is left unset.
	tests := []struct {
		name     string
		key      string
		setValue any
		want     bool
	}{
		{
			name:     "set key",
			key:      "test.key",
			setValue: "value",
			want:     true,
		},
		{
			name:     "non-existent key",
			key:      "test.missing",
			setValue: nil,
			want:     false,
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			t.Parallel()

			v := viper.New()
			if tt.setValue != nil {
				v.Set(tt.key, tt.setValue)
			}

			cfg := NewViperConfig(v)
			if got := cfg.IsSet(tt.key); got != tt.want {
				t.Errorf("IsSet(%q) = %v, want %v", tt.key, got, tt.want)
			}
		})
	}
}

// TestViperConfig_Unmarshal checks that nested keys are decoded into a struct
// via its mapstructure tags.
func TestViperConfig_Unmarshal(t *testing.T) {
	t.Parallel()

	type Config struct {
		Server struct {
			Port int    `mapstructure:"port"`
			Host string `mapstructure:"host"`
		} `mapstructure:"server"`
		Logging struct {
			Level  string `mapstructure:"level"`
			Format string `mapstructure:"format"`
		} `mapstructure:"logging"`
	}

	v := viper.New()
	v.Set("server.port", 8080)
	v.Set("server.host", "localhost")
	v.Set("logging.level", "debug")
	v.Set("logging.format", "json")

	cfg := NewViperConfig(v)

	var result Config
	if err := cfg.Unmarshal(&result); err != nil {
		t.Fatalf("Unmarshal failed: %v", err)
	}

	if result.Server.Port != 8080 {
		t.Errorf("Server.Port = %d, want 8080", result.Server.Port)
	}
	if result.Server.Host != "localhost" {
		t.Errorf("Server.Host = %q, want %q", result.Server.Host, "localhost")
	}
	if result.Logging.Level != "debug" {
		t.Errorf("Logging.Level = %q, want %q", result.Logging.Level, "debug")
	}
	if result.Logging.Format != "json" {
		t.Errorf("Logging.Format = %q, want %q", result.Logging.Format, "json")
	}
}

// writeConfigFiles creates a "config" directory under root and writes each
// entry of files into it, keyed by file name. Any failure aborts the test.
func writeConfigFiles(t *testing.T, root string, files map[string]string) {
	t.Helper()

	configDir := filepath.Join(root, "config")
	if err := os.MkdirAll(configDir, 0o750); err != nil {
		t.Fatalf("Failed to create config dir: %v", err)
	}
	for name, contents := range files {
		if err := os.WriteFile(filepath.Join(configDir, name), []byte(contents), 0o600); err != nil {
			t.Fatalf("Failed to write %s: %v", name, err)
		}
	}
}

// chdirForTest switches the process working directory to dir for the rest of
// the test and restores the original directory during cleanup. The working
// directory is process-global, so tests calling this must not use t.Parallel.
// (On Go 1.24+, t.Chdir provides the same behavior.)
func chdirForTest(t *testing.T, dir string) {
	t.Helper()

	originalDir, err := os.Getwd()
	if err != nil {
		t.Fatalf("Failed to get current directory: %v", err)
	}
	if err := os.Chdir(dir); err != nil {
		t.Fatalf("Failed to change directory: %v", err)
	}
	// Cleanups run LIFO, so this restore runs before the removal of any
	// t.TempDir created earlier in the test.
	t.Cleanup(func() {
		// Report rather than just log: a leaked working directory would break
		// later tests that resolve relative paths.
		if err := os.Chdir(originalDir); err != nil {
			t.Errorf("Failed to restore directory: %v", err)
		}
	})
}

// TestLoadConfig_Default checks that LoadConfig("") reads config/default.yaml
// relative to the working directory.
func TestLoadConfig_Default(t *testing.T) {
	// Not parallel: chdirForTest changes the process-wide working directory.
	const defaultYAML = `server:
  port: 8080
  host: "localhost"
logging:
  level: "info"
  format: "json"
`

	tmpDir := t.TempDir()
	writeConfigFiles(t, tmpDir, map[string]string{"default.yaml": defaultYAML})
	chdirForTest(t, tmpDir)

	cfg, err := LoadConfig("")
	if err != nil {
		t.Fatalf("LoadConfig failed: %v", err)
	}
	if cfg == nil {
		t.Fatal("LoadConfig returned nil")
	}

	if got := cfg.GetString("server.host"); got != "localhost" {
		t.Errorf("server.host = %q, want %q", got, "localhost")
	}
	if got := cfg.GetInt("server.port"); got != 8080 {
		t.Errorf("server.port = %d, want 8080", got)
	}
}

// TestLoadConfig_WithEnvironment checks that an environment-specific file
// overrides individual keys from default.yaml while keys it does not mention
// keep their default values.
func TestLoadConfig_WithEnvironment(t *testing.T) {
	// Not parallel: chdirForTest changes the process-wide working directory.
	const defaultYAML = `server:
  port: 8080
  host: "localhost"
logging:
  level: "info"
`
	const devYAML = `server:
  port: 3000
logging:
  level: "debug"
`

	tmpDir := t.TempDir()
	writeConfigFiles(t, tmpDir, map[string]string{
		"default.yaml":     defaultYAML,
		"development.yaml": devYAML,
	})
	chdirForTest(t, tmpDir)

	cfg, err := LoadConfig("development")
	if err != nil {
		t.Fatalf("LoadConfig failed: %v", err)
	}
	if cfg == nil {
		t.Fatal("LoadConfig returned nil")
	}

	// Development config should override port.
	if got := cfg.GetInt("server.port"); got != 3000 {
		t.Errorf("server.port = %d, want 3000", got)
	}
	// Development config should override logging level.
	if got := cfg.GetString("logging.level"); got != "debug" {
		t.Errorf("logging.level = %q, want %q", got, "debug")
	}
	// Default host should still be present.
	if got := cfg.GetString("server.host"); got != "localhost" {
		t.Errorf("server.host = %q, want %q", got, "localhost")
	}
}

// TestLoadConfig_MissingDefaultFile checks that LoadConfig fails when
// config/default.yaml does not exist.
func TestLoadConfig_MissingDefaultFile(t *testing.T) {
	// Not parallel: chdirForTest changes the process-wide working directory.
	chdirForTest(t, t.TempDir())

	if _, err := LoadConfig(""); err == nil {
		t.Fatal("LoadConfig should fail when default.yaml is missing")
	}
}