diff --git a/cmd/thv/app/otel.go b/cmd/thv/app/otel.go index cc955befc..eb3c80bd1 100644 --- a/cmd/thv/app/otel.go +++ b/cmd/thv/app/otel.go @@ -2,8 +2,6 @@ package app import ( "fmt" - "strconv" - "strings" "github.com/spf13/cobra" @@ -17,561 +15,119 @@ var OtelCmd = &cobra.Command{ Long: "Configure OpenTelemetry settings for observability and monitoring of MCP servers.", } -var setOtelEndpointCmd = &cobra.Command{ - Use: "set-endpoint ", - Short: "Set the OpenTelemetry endpoint URL", - Long: `Set the OpenTelemetry OTLP endpoint URL for tracing and metrics. - -This endpoint will be used by default when running MCP servers unless overridden by the --otel-endpoint flag. - -Example: - - thv config otel set-endpoint https://api.honeycomb.io`, - Args: cobra.ExactArgs(1), - RunE: setOtelEndpointCmdFunc, -} - -var getOtelEndpointCmd = &cobra.Command{ - Use: "get-endpoint", - Short: "Get the currently configured OpenTelemetry endpoint", - Long: "Display the OpenTelemetry endpoint URL that is currently configured.", - RunE: getOtelEndpointCmdFunc, -} - -var unsetOtelEndpointCmd = &cobra.Command{ - Use: "unset-endpoint", - Short: "Remove the configured OpenTelemetry endpoint", - Long: "Remove the OpenTelemetry endpoint configuration.", - RunE: unsetOtelEndpointCmdFunc, -} - -var setOtelMetricsEnabledCmd = &cobra.Command{ - Use: "set-metrics-enabled ", - Short: "Set the OpenTelemetry metrics export to enabled", - Long: `Set the OpenTelemetry metrics flag to enable to export metrics to an OTel collector. - - thv config otel set-metrics-enabled true`, - Args: cobra.ExactArgs(1), - RunE: setOtelMetricsEnabledCmdFunc, -} - -var getOtelMetricsEnabledCmd = &cobra.Command{ - Use: "get-metrics-enabled", - Short: "Get the currently configured OpenTelemetry metrics export flag", - Long: "Display the OpenTelemetry metrics export flag that is currently configured.", - RunE: getOtelMetricsEnabledCmdFunc, -} - -var unsetOtelMetricsEnabledCmd = &cobra.Command{ - Use: "unset-metrics-enabled", - Short: "Remove the configured OpenTelemetry metrics export flag", - Long: "Remove the OpenTelemetry metrics export flag configuration.", - RunE: unsetOtelMetricsEnabledCmdFunc, -} - -var setOtelTracingEnabledCmd = &cobra.Command{ - Use: "set-tracing-enabled ", - Short: "Set the OpenTelemetry tracing export to enabled", - Long: `Set the OpenTelemetry tracing flag to enable to export traces to an OTel collector. - - thv config otel set-tracing-enabled true`, - Args: cobra.ExactArgs(1), - RunE: setOtelTracingEnabledCmdFunc, -} - -var getOtelTracingEnabledCmd = &cobra.Command{ - Use: "get-tracing-enabled", - Short: "Get the currently configured OpenTelemetry tracing export flag", - Long: "Display the OpenTelemetry tracing export flag that is currently configured.", - RunE: getOtelTracingEnabledCmdFunc, -} - -var unsetOtelTracingEnabledCmd = &cobra.Command{ - Use: "unset-tracing-enabled", - Short: "Remove the configured OpenTelemetry tracing export flag", - Long: "Remove the OpenTelemetry tracing export flag configuration.", - RunE: unsetOtelTracingEnabledCmdFunc, -} - -var setOtelSamplingRateCmd = &cobra.Command{ - Use: "set-sampling-rate ", - Short: "Set the OpenTelemetry sampling rate", - Long: `Set the OpenTelemetry trace sampling rate (between 0.0 and 1.0). - -This sampling rate will be used by default when running MCP servers unless overridden by the --otel-sampling-rate flag. - -Example: - - thv config otel set-sampling-rate 0.1`, - Args: cobra.ExactArgs(1), - RunE: setOtelSamplingRateCmdFunc, -} - -var getOtelSamplingRateCmd = &cobra.Command{ - Use: "get-sampling-rate", - Short: "Get the currently configured OpenTelemetry sampling rate", - Long: "Display the OpenTelemetry sampling rate that is currently configured.", - RunE: getOtelSamplingRateCmdFunc, -} - -var unsetOtelSamplingRateCmd = &cobra.Command{ - Use: "unset-sampling-rate", - Short: "Remove the configured OpenTelemetry sampling rate", - Long: "Remove the OpenTelemetry sampling rate configuration.", - RunE: unsetOtelSamplingRateCmdFunc, -} - -var setOtelEnvVarsCmd = &cobra.Command{ - Use: "set-env-vars ", - Short: "Set the OpenTelemetry environment variables", - Long: `Set the list of environment variable names to include in OpenTelemetry spans. - -These environment variables will be used by default when running MCP servers unless overridden by the --otel-env-vars flag. - -Example: - - thv config otel set-env-vars USER,HOME,PATH`, - Args: cobra.ExactArgs(1), - RunE: setOtelEnvVarsCmdFunc, -} - -var getOtelEnvVarsCmd = &cobra.Command{ - Use: "get-env-vars", - Short: "Get the currently configured OpenTelemetry environment variables", - Long: "Display the OpenTelemetry environment variables that are currently configured.", - RunE: getOtelEnvVarsCmdFunc, -} - -var unsetOtelEnvVarsCmd = &cobra.Command{ - Use: "unset-env-vars", - Short: "Remove the configured OpenTelemetry environment variables", - Long: "Remove the OpenTelemetry environment variables configuration.", - RunE: unsetOtelEnvVarsCmdFunc, -} - -var setOtelInsecureCmd = &cobra.Command{ - Use: "set-insecure ", - Short: "Set the OpenTelemetry insecure transport flag", - Long: `Set the OpenTelemetry insecure flag to enable HTTP instead of HTTPS for OTLP endpoints. - - thv config otel set-insecure true`, - Args: cobra.ExactArgs(1), - RunE: setOtelInsecureCmdFunc, -} - -var getOtelInsecureCmd = &cobra.Command{ - Use: "get-insecure", - Short: "Get the currently configured OpenTelemetry insecure transport flag", - Long: "Display the OpenTelemetry insecure transport flag that is currently configured.", - RunE: getOtelInsecureCmdFunc, -} - -var unsetOtelInsecureCmd = &cobra.Command{ - Use: "unset-insecure", - Short: "Remove the configured OpenTelemetry insecure transport flag", - Long: "Remove the OpenTelemetry insecure transport flag configuration.", - RunE: unsetOtelInsecureCmdFunc, -} - -var setOtelEnablePrometheusMetricsPathCmd = &cobra.Command{ - Use: "set-enable-prometheus-metrics-path ", - Short: "Set the OpenTelemetry Prometheus metrics path flag", - Long: `Set the OpenTelemetry Prometheus metrics path flag to enable /metrics endpoint. - - thv config otel set-enable-prometheus-metrics-path true`, - Args: cobra.ExactArgs(1), - RunE: setOtelEnablePrometheusMetricsPathCmdFunc, -} - -var getOtelEnablePrometheusMetricsPathCmd = &cobra.Command{ - Use: "get-enable-prometheus-metrics-path", - Short: "Get the currently configured OpenTelemetry Prometheus metrics path flag", - Long: "Display the OpenTelemetry Prometheus metrics path flag that is currently configured.", - RunE: getOtelEnablePrometheusMetricsPathCmdFunc, -} - -var unsetOtelEnablePrometheusMetricsPathCmd = &cobra.Command{ - Use: "unset-enable-prometheus-metrics-path", - Short: "Remove the configured OpenTelemetry Prometheus metrics path flag", - Long: "Remove the OpenTelemetry Prometheus metrics path flag configuration.", - RunE: unsetOtelEnablePrometheusMetricsPathCmdFunc, -} - -// init sets up the OTEL command hierarchy -func init() { - // Add OTEL subcommands to otel command - OtelCmd.AddCommand(setOtelEndpointCmd) - OtelCmd.AddCommand(getOtelEndpointCmd) - OtelCmd.AddCommand(unsetOtelEndpointCmd) - OtelCmd.AddCommand(setOtelMetricsEnabledCmd) - OtelCmd.AddCommand(getOtelMetricsEnabledCmd) - OtelCmd.AddCommand(unsetOtelMetricsEnabledCmd) - OtelCmd.AddCommand(setOtelTracingEnabledCmd) - OtelCmd.AddCommand(getOtelTracingEnabledCmd) - OtelCmd.AddCommand(unsetOtelTracingEnabledCmd) - OtelCmd.AddCommand(setOtelSamplingRateCmd) - OtelCmd.AddCommand(getOtelSamplingRateCmd) - OtelCmd.AddCommand(unsetOtelSamplingRateCmd) - OtelCmd.AddCommand(setOtelEnvVarsCmd) - OtelCmd.AddCommand(getOtelEnvVarsCmd) - OtelCmd.AddCommand(unsetOtelEnvVarsCmd) - OtelCmd.AddCommand(setOtelInsecureCmd) - OtelCmd.AddCommand(getOtelInsecureCmd) - OtelCmd.AddCommand(unsetOtelInsecureCmd) - OtelCmd.AddCommand(setOtelEnablePrometheusMetricsPathCmd) - OtelCmd.AddCommand(getOtelEnablePrometheusMetricsPathCmd) - OtelCmd.AddCommand(unsetOtelEnablePrometheusMetricsPathCmd) -} - -func setOtelEndpointCmdFunc(_ *cobra.Command, args []string) error { - endpoint := args[0] - - // The endpoint should not start with http:// or https:// - if endpoint != "" && (strings.HasPrefix(endpoint, "http://") || strings.HasPrefix(endpoint, "https://")) { - return fmt.Errorf("endpoint URL should not start with http:// or https://") - } - - // Update the configuration - err := config.UpdateConfig(func(c *config.Config) { - c.OTEL.Endpoint = endpoint - }) - if err != nil { - return fmt.Errorf("failed to update configuration: %w", err) - } - - fmt.Printf("Successfully set OpenTelemetry endpoint: %s\n", endpoint) - return nil -} - -func getOtelEndpointCmdFunc(_ *cobra.Command, _ []string) error { - configProvider := config.NewDefaultProvider() - cfg := configProvider.GetConfig() - - if cfg.OTEL.Endpoint == "" { - fmt.Println("No OpenTelemetry endpoint is currently configured.") - return nil - } - - fmt.Printf("Current OpenTelemetry endpoint: %s\n", cfg.OTEL.Endpoint) - return nil -} - -func unsetOtelEndpointCmdFunc(_ *cobra.Command, _ []string) error { - configProvider := config.NewDefaultProvider() - cfg := configProvider.GetConfig() - - if cfg.OTEL.Endpoint == "" { - fmt.Println("No OpenTelemetry endpoint is currently configured.") - return nil - } - - // Update the configuration - err := config.UpdateConfig(func(c *config.Config) { - c.OTEL.Endpoint = "" - }) - if err != nil { - return fmt.Errorf("failed to update configuration: %w", err) - } - - fmt.Println("Successfully removed OpenTelemetry endpoint configuration.") - return nil -} - -func setOtelSamplingRateCmdFunc(_ *cobra.Command, args []string) error { - rate, err := strconv.ParseFloat(args[0], 64) - if err != nil { - return fmt.Errorf("invalid sampling rate format: %w", err) - } - - // Validate the rate - if rate < 0.0 || rate > 1.0 { - return fmt.Errorf("sampling rate must be between 0.0 and 1.0") - } - - // Update the configuration - err = config.UpdateConfig(func(c *config.Config) { - c.OTEL.SamplingRate = rate - }) - if err != nil { - return fmt.Errorf("failed to update configuration: %w", err) - } - - fmt.Printf("Successfully set OpenTelemetry sampling rate: %f\n", rate) - return nil -} - -func getOtelSamplingRateCmdFunc(_ *cobra.Command, _ []string) error { - configProvider := config.NewDefaultProvider() - cfg := configProvider.GetConfig() - - if cfg.OTEL.SamplingRate == 0.0 { - fmt.Println("No OpenTelemetry sampling rate is currently configured.") - return nil - } - - fmt.Printf("Current OpenTelemetry sampling rate: %f\n", cfg.OTEL.SamplingRate) - return nil -} - -func unsetOtelSamplingRateCmdFunc(_ *cobra.Command, _ []string) error { - configProvider := config.NewDefaultProvider() - cfg := configProvider.GetConfig() - - if cfg.OTEL.SamplingRate == 0.0 { - fmt.Println("No OpenTelemetry sampling rate is currently configured.") - return nil - } - - // Update the configuration - err := config.UpdateConfig(func(c *config.Config) { - c.OTEL.SamplingRate = 0.0 - }) - if err != nil { - return fmt.Errorf("failed to update configuration: %w", err) - } - - fmt.Println("Successfully removed OpenTelemetry sampling rate configuration.") - return nil -} - -func setOtelEnvVarsCmdFunc(_ *cobra.Command, args []string) error { - vars := strings.Split(args[0], ",") - - // Trim whitespace from each variable name - for i, varName := range vars { - vars[i] = strings.TrimSpace(varName) - } - - // Update the configuration - err := config.UpdateConfig(func(c *config.Config) { - c.OTEL.EnvVars = vars - }) - if err != nil { - return fmt.Errorf("failed to update configuration: %w", err) - } - - fmt.Printf("Successfully set OpenTelemetry environment variables: %v\n", vars) - return nil -} - -func getOtelEnvVarsCmdFunc(_ *cobra.Command, _ []string) error { - configProvider := config.NewDefaultProvider() - cfg := configProvider.GetConfig() - - if len(cfg.OTEL.EnvVars) == 0 { - fmt.Println("No OpenTelemetry environment variables are currently configured.") - return nil - } - - fmt.Printf("Current OpenTelemetry environment variables: %v\n", cfg.OTEL.EnvVars) - return nil -} - -func unsetOtelEnvVarsCmdFunc(_ *cobra.Command, _ []string) error { - configProvider := config.NewDefaultProvider() - cfg := configProvider.GetConfig() - - if len(cfg.OTEL.EnvVars) == 0 { - fmt.Println("No OpenTelemetry environment variables are currently configured.") - return nil - } - - // Update the configuration - err := config.UpdateConfig(func(c *config.Config) { - c.OTEL.EnvVars = []string{} - }) - if err != nil { - return fmt.Errorf("failed to update configuration: %w", err) - } - - fmt.Println("Successfully removed OpenTelemetry environment variables configuration.") - return nil -} - -func setOtelMetricsEnabledCmdFunc(_ *cobra.Command, args []string) error { - enabled, err := strconv.ParseBool(args[0]) - if err != nil { - return fmt.Errorf("invalid boolean value for metrics enabled flag: %w", err) - } - - // Update the configuration - err = config.UpdateConfig(func(c *config.Config) { - c.OTEL.MetricsEnabled = enabled - }) - if err != nil { - return fmt.Errorf("failed to update configuration: %w", err) - } - - fmt.Printf("Successfully set OpenTelemetry metrics enabled: %t\n", enabled) - return nil -} - -func getOtelMetricsEnabledCmdFunc(_ *cobra.Command, _ []string) error { - configProvider := config.NewDefaultProvider() - cfg := configProvider.GetConfig() - - fmt.Printf("Current OpenTelemetry metrics enabled: %t\n", cfg.OTEL.MetricsEnabled) - return nil -} - -func unsetOtelMetricsEnabledCmdFunc(_ *cobra.Command, _ []string) error { - configProvider := config.NewDefaultProvider() - cfg := configProvider.GetConfig() - - if !cfg.OTEL.MetricsEnabled { - fmt.Println("OpenTelemetry metrics enabled is already disabled.") - return nil - } - - // Update the configuration - err := config.UpdateConfig(func(c *config.Config) { - c.OTEL.MetricsEnabled = false - }) - if err != nil { - return fmt.Errorf("failed to update configuration: %w", err) - } - - fmt.Println("Successfully disabled OpenTelemetry metrics enabled configuration.") - return nil -} - -func setOtelTracingEnabledCmdFunc(_ *cobra.Command, args []string) error { - enabled, err := strconv.ParseBool(args[0]) - if err != nil { - return fmt.Errorf("invalid boolean value for tracing enabled flag: %w", err) +// createOTELSetCommand creates a generic set command for an OTEL field +func createOTELSetCommand(fieldName, commandName, description, example string) *cobra.Command { + return &cobra.Command{ + Use: fmt.Sprintf("set-%s <%s>", commandName, commandName), + Short: fmt.Sprintf("Set the OpenTelemetry %s", description), + Long: fmt.Sprintf("Set the OpenTelemetry %s.\n\nExample:\n\n\tthv config otel set-%s %s", description, commandName, example), + Args: cobra.ExactArgs(1), + RunE: func(_ *cobra.Command, args []string) error { + provider := config.NewDefaultProvider() + err := config.SetConfigField(provider, fieldName, args[0]) + if err != nil { + return err + } + fmt.Printf("Successfully set OpenTelemetry %s: %s\n", description, args[0]) + return nil + }, + } +} + +// createOTELGetCommand creates a generic get command for an OTEL field +func createOTELGetCommand(fieldName, commandName, description string) *cobra.Command { + return &cobra.Command{ + Use: fmt.Sprintf("get-%s", commandName), + Short: fmt.Sprintf("Get the currently configured OpenTelemetry %s", description), + Long: fmt.Sprintf("Display the OpenTelemetry %s that is currently configured.", description), + RunE: func(_ *cobra.Command, _ []string) error { + provider := config.NewDefaultProvider() + value, isSet, err := config.GetConfigField(provider, fieldName) + if err != nil { + return err + } + + if !isSet { + fmt.Printf("No OpenTelemetry %s is currently configured.\n", description) + return nil + } + + fmt.Printf("Current OpenTelemetry %s: %s\n", description, value) + return nil + }, + } +} + +// createOTELUnsetCommand creates a generic unset command for an OTEL field +func createOTELUnsetCommand(fieldName, commandName, description string) *cobra.Command { + return &cobra.Command{ + Use: fmt.Sprintf("unset-%s", commandName), + Short: fmt.Sprintf("Remove the configured OpenTelemetry %s", description), + Long: fmt.Sprintf("Remove the OpenTelemetry %s configuration.", description), + RunE: func(_ *cobra.Command, _ []string) error { + provider := config.NewDefaultProvider() + + // Check if it's set before unsetting + _, isSet, err := config.GetConfigField(provider, fieldName) + if err != nil { + return err + } + + if !isSet { + fmt.Printf("No OpenTelemetry %s is currently configured.\n", description) + return nil + } + + err = config.UnsetConfigField(provider, fieldName) + if err != nil { + return err + } + + fmt.Printf("Successfully removed OpenTelemetry %s configuration.\n", description) + return nil + }, } - - // Update the configuration - err = config.UpdateConfig(func(c *config.Config) { - c.OTEL.TracingEnabled = enabled - }) - if err != nil { - return fmt.Errorf("failed to update configuration: %w", err) - } - - fmt.Printf("Successfully set OpenTelemetry tracing enabled: %t\n", enabled) - return nil -} - -func getOtelTracingEnabledCmdFunc(_ *cobra.Command, _ []string) error { - configProvider := config.NewDefaultProvider() - cfg := configProvider.GetConfig() - - fmt.Printf("Current OpenTelemetry tracing enabled: %t\n", cfg.OTEL.TracingEnabled) - return nil -} - -func unsetOtelTracingEnabledCmdFunc(_ *cobra.Command, _ []string) error { - configProvider := config.NewDefaultProvider() - cfg := configProvider.GetConfig() - - if !cfg.OTEL.TracingEnabled { - fmt.Println("OpenTelemetry tracing enabled is already disabled.") - return nil - } - - // Update the configuration - err := config.UpdateConfig(func(c *config.Config) { - c.OTEL.TracingEnabled = false - }) - if err != nil { - return fmt.Errorf("failed to update configuration: %w", err) - } - - fmt.Println("Successfully disabled OpenTelemetry tracing enabled configuration.") - return nil -} - -func setOtelInsecureCmdFunc(_ *cobra.Command, args []string) error { - enabled, err := strconv.ParseBool(args[0]) - if err != nil { - return fmt.Errorf("invalid boolean value for insecure flag: %w", err) - } - - // Update the configuration - err = config.UpdateConfig(func(c *config.Config) { - c.OTEL.Insecure = enabled - }) - if err != nil { - return fmt.Errorf("failed to update configuration: %w", err) - } - - fmt.Printf("Successfully set OpenTelemetry insecure transport: %t\n", enabled) - return nil -} - -func getOtelInsecureCmdFunc(_ *cobra.Command, _ []string) error { - configProvider := config.NewDefaultProvider() - cfg := configProvider.GetConfig() - - fmt.Printf("Current OpenTelemetry insecure transport: %t\n", cfg.OTEL.Insecure) - return nil } -func unsetOtelInsecureCmdFunc(_ *cobra.Command, _ []string) error { - configProvider := config.NewDefaultProvider() - cfg := configProvider.GetConfig() - - if !cfg.OTEL.Insecure { - fmt.Println("OpenTelemetry insecure transport is already disabled.") - return nil - } - - // Update the configuration - err := config.UpdateConfig(func(c *config.Config) { - c.OTEL.Insecure = false - }) - if err != nil { - return fmt.Errorf("failed to update configuration: %w", err) - } - - fmt.Println("Successfully disabled OpenTelemetry insecure transport configuration.") - return nil -} - -func setOtelEnablePrometheusMetricsPathCmdFunc(_ *cobra.Command, args []string) error { - enabled, err := strconv.ParseBool(args[0]) - if err != nil { - return fmt.Errorf("invalid boolean value for Prometheus metrics path flag: %w", err) - } - - // Update the configuration - err = config.UpdateConfig(func(c *config.Config) { - c.OTEL.EnablePrometheusMetricsPath = enabled - }) - if err != nil { - return fmt.Errorf("failed to update configuration: %w", err) - } - - fmt.Printf("Successfully set Prometheus metrics path: %t\n", enabled) - return nil -} - -func getOtelEnablePrometheusMetricsPathCmdFunc(_ *cobra.Command, _ []string) error { - configProvider := config.NewDefaultProvider() - cfg := configProvider.GetConfig() - - fmt.Printf("Current Prometheus metrics path flag: %t\n", cfg.OTEL.EnablePrometheusMetricsPath) - return nil -} - -func unsetOtelEnablePrometheusMetricsPathCmdFunc(_ *cobra.Command, _ []string) error { - configProvider := config.NewDefaultProvider() - cfg := configProvider.GetConfig() - - if !cfg.OTEL.EnablePrometheusMetricsPath { - fmt.Println("Prometheus metrics path is already disabled.") - return nil - } - - // Update the configuration - err := config.UpdateConfig(func(c *config.Config) { - c.OTEL.EnablePrometheusMetricsPath = false - }) - if err != nil { - return fmt.Errorf("failed to update configuration: %w", err) - } - - fmt.Println("Successfully disabled the Prometheus metrics path configuration.") - return nil +func init() { + // Endpoint commands + OtelCmd.AddCommand(createOTELSetCommand("otel-endpoint", "endpoint", "endpoint URL", "https://api.honeycomb.io")) + OtelCmd.AddCommand(createOTELGetCommand("otel-endpoint", "endpoint", "endpoint")) + OtelCmd.AddCommand(createOTELUnsetCommand("otel-endpoint", "endpoint", "endpoint")) + + // Sampling rate commands + OtelCmd.AddCommand(createOTELSetCommand("otel-sampling-rate", "sampling-rate", "sampling rate", "0.5")) + OtelCmd.AddCommand(createOTELGetCommand("otel-sampling-rate", "sampling-rate", "sampling rate")) + OtelCmd.AddCommand(createOTELUnsetCommand("otel-sampling-rate", "sampling-rate", "sampling rate")) + + // Environment variables commands + OtelCmd.AddCommand(createOTELSetCommand("otel-env-vars", "env-vars", "environment variables", "VAR1,VAR2,VAR3")) + OtelCmd.AddCommand(createOTELGetCommand("otel-env-vars", "env-vars", "environment variables")) + OtelCmd.AddCommand(createOTELUnsetCommand("otel-env-vars", "env-vars", "environment variables")) + + // Metrics enabled commands + OtelCmd.AddCommand(createOTELSetCommand("otel-metrics-enabled", "metrics-enabled", "metrics export flag", "true")) + OtelCmd.AddCommand(createOTELGetCommand("otel-metrics-enabled", "metrics-enabled", "metrics export flag")) + OtelCmd.AddCommand(createOTELUnsetCommand("otel-metrics-enabled", "metrics-enabled", "metrics export flag")) + + // Tracing enabled commands + OtelCmd.AddCommand(createOTELSetCommand("otel-tracing-enabled", "tracing-enabled", "tracing export flag", "true")) + OtelCmd.AddCommand(createOTELGetCommand("otel-tracing-enabled", "tracing-enabled", "tracing export flag")) + OtelCmd.AddCommand(createOTELUnsetCommand("otel-tracing-enabled", "tracing-enabled", "tracing export flag")) + + // Insecure commands + OtelCmd.AddCommand(createOTELSetCommand("otel-insecure", "insecure", "insecure connection flag", "true")) + OtelCmd.AddCommand(createOTELGetCommand("otel-insecure", "insecure", "insecure connection flag")) + OtelCmd.AddCommand(createOTELUnsetCommand("otel-insecure", "insecure", "insecure connection flag")) + + // Enable Prometheus metrics path commands + OtelCmd.AddCommand(createOTELSetCommand( + "otel-enable-prometheus-metrics-path", "enable-prometheus-metrics-path", + "Prometheus metrics path flag", "true")) + OtelCmd.AddCommand(createOTELGetCommand( + "otel-enable-prometheus-metrics-path", "enable-prometheus-metrics-path", + "Prometheus metrics path flag")) + OtelCmd.AddCommand(createOTELUnsetCommand( + "otel-enable-prometheus-metrics-path", "enable-prometheus-metrics-path", + "Prometheus metrics path flag")) } diff --git a/docs/cli/thv_config_otel.md b/docs/cli/thv_config_otel.md index 24781da35..701fd3d3a 100644 --- a/docs/cli/thv_config_otel.md +++ b/docs/cli/thv_config_otel.md @@ -35,21 +35,21 @@ Configure OpenTelemetry settings for observability and monitoring of MCP servers * [thv config otel get-enable-prometheus-metrics-path](thv_config_otel_get-enable-prometheus-metrics-path.md) - Get the currently configured OpenTelemetry Prometheus metrics path flag * [thv config otel get-endpoint](thv_config_otel_get-endpoint.md) - Get the currently configured OpenTelemetry endpoint * [thv config otel get-env-vars](thv_config_otel_get-env-vars.md) - Get the currently configured OpenTelemetry environment variables -* [thv config otel get-insecure](thv_config_otel_get-insecure.md) - Get the currently configured OpenTelemetry insecure transport flag +* [thv config otel get-insecure](thv_config_otel_get-insecure.md) - Get the currently configured OpenTelemetry insecure connection flag * [thv config otel get-metrics-enabled](thv_config_otel_get-metrics-enabled.md) - Get the currently configured OpenTelemetry metrics export flag * [thv config otel get-sampling-rate](thv_config_otel_get-sampling-rate.md) - Get the currently configured OpenTelemetry sampling rate * [thv config otel get-tracing-enabled](thv_config_otel_get-tracing-enabled.md) - Get the currently configured OpenTelemetry tracing export flag * [thv config otel set-enable-prometheus-metrics-path](thv_config_otel_set-enable-prometheus-metrics-path.md) - Set the OpenTelemetry Prometheus metrics path flag * [thv config otel set-endpoint](thv_config_otel_set-endpoint.md) - Set the OpenTelemetry endpoint URL * [thv config otel set-env-vars](thv_config_otel_set-env-vars.md) - Set the OpenTelemetry environment variables -* [thv config otel set-insecure](thv_config_otel_set-insecure.md) - Set the OpenTelemetry insecure transport flag -* [thv config otel set-metrics-enabled](thv_config_otel_set-metrics-enabled.md) - Set the OpenTelemetry metrics export to enabled +* [thv config otel set-insecure](thv_config_otel_set-insecure.md) - Set the OpenTelemetry insecure connection flag +* [thv config otel set-metrics-enabled](thv_config_otel_set-metrics-enabled.md) - Set the OpenTelemetry metrics export flag * [thv config otel set-sampling-rate](thv_config_otel_set-sampling-rate.md) - Set the OpenTelemetry sampling rate -* [thv config otel set-tracing-enabled](thv_config_otel_set-tracing-enabled.md) - Set the OpenTelemetry tracing export to enabled +* [thv config otel set-tracing-enabled](thv_config_otel_set-tracing-enabled.md) - Set the OpenTelemetry tracing export flag * [thv config otel unset-enable-prometheus-metrics-path](thv_config_otel_unset-enable-prometheus-metrics-path.md) - Remove the configured OpenTelemetry Prometheus metrics path flag * [thv config otel unset-endpoint](thv_config_otel_unset-endpoint.md) - Remove the configured OpenTelemetry endpoint * [thv config otel unset-env-vars](thv_config_otel_unset-env-vars.md) - Remove the configured OpenTelemetry environment variables -* [thv config otel unset-insecure](thv_config_otel_unset-insecure.md) - Remove the configured OpenTelemetry insecure transport flag +* [thv config otel unset-insecure](thv_config_otel_unset-insecure.md) - Remove the configured OpenTelemetry insecure connection flag * [thv config otel unset-metrics-enabled](thv_config_otel_unset-metrics-enabled.md) - Remove the configured OpenTelemetry metrics export flag * [thv config otel unset-sampling-rate](thv_config_otel_unset-sampling-rate.md) - Remove the configured OpenTelemetry sampling rate * [thv config otel unset-tracing-enabled](thv_config_otel_unset-tracing-enabled.md) - Remove the configured OpenTelemetry tracing export flag diff --git a/docs/cli/thv_config_otel_get-endpoint.md b/docs/cli/thv_config_otel_get-endpoint.md index f3d4f7ad1..6d0d11b2a 100644 --- a/docs/cli/thv_config_otel_get-endpoint.md +++ b/docs/cli/thv_config_otel_get-endpoint.md @@ -15,7 +15,7 @@ Get the currently configured OpenTelemetry endpoint ### Synopsis -Display the OpenTelemetry endpoint URL that is currently configured. +Display the OpenTelemetry endpoint that is currently configured. ``` thv config otel get-endpoint [flags] diff --git a/docs/cli/thv_config_otel_get-env-vars.md b/docs/cli/thv_config_otel_get-env-vars.md index bfb107e3e..34636098a 100644 --- a/docs/cli/thv_config_otel_get-env-vars.md +++ b/docs/cli/thv_config_otel_get-env-vars.md @@ -15,7 +15,7 @@ Get the currently configured OpenTelemetry environment variables ### Synopsis -Display the OpenTelemetry environment variables that are currently configured. +Display the OpenTelemetry environment variables that is currently configured. ``` thv config otel get-env-vars [flags] diff --git a/docs/cli/thv_config_otel_get-insecure.md b/docs/cli/thv_config_otel_get-insecure.md index bea3eea3c..984389631 100644 --- a/docs/cli/thv_config_otel_get-insecure.md +++ b/docs/cli/thv_config_otel_get-insecure.md @@ -11,11 +11,11 @@ mdx: ## thv config otel get-insecure -Get the currently configured OpenTelemetry insecure transport flag +Get the currently configured OpenTelemetry insecure connection flag ### Synopsis -Display the OpenTelemetry insecure transport flag that is currently configured. +Display the OpenTelemetry insecure connection flag that is currently configured. ``` thv config otel get-insecure [flags] diff --git a/docs/cli/thv_config_otel_set-enable-prometheus-metrics-path.md b/docs/cli/thv_config_otel_set-enable-prometheus-metrics-path.md index 189ecf5a9..7f4d61b57 100644 --- a/docs/cli/thv_config_otel_set-enable-prometheus-metrics-path.md +++ b/docs/cli/thv_config_otel_set-enable-prometheus-metrics-path.md @@ -15,12 +15,14 @@ Set the OpenTelemetry Prometheus metrics path flag ### Synopsis -Set the OpenTelemetry Prometheus metrics path flag to enable /metrics endpoint. +Set the OpenTelemetry Prometheus metrics path flag. + +Example: thv config otel set-enable-prometheus-metrics-path true ``` -thv config otel set-enable-prometheus-metrics-path [flags] +thv config otel set-enable-prometheus-metrics-path [flags] ``` ### Options diff --git a/docs/cli/thv_config_otel_set-endpoint.md b/docs/cli/thv_config_otel_set-endpoint.md index 49cf5a4e5..da468599d 100644 --- a/docs/cli/thv_config_otel_set-endpoint.md +++ b/docs/cli/thv_config_otel_set-endpoint.md @@ -15,9 +15,7 @@ Set the OpenTelemetry endpoint URL ### Synopsis -Set the OpenTelemetry OTLP endpoint URL for tracing and metrics. - -This endpoint will be used by default when running MCP servers unless overridden by the --otel-endpoint flag. +Set the OpenTelemetry endpoint URL. Example: diff --git a/docs/cli/thv_config_otel_set-env-vars.md b/docs/cli/thv_config_otel_set-env-vars.md index cc84588a7..fb96a1b9d 100644 --- a/docs/cli/thv_config_otel_set-env-vars.md +++ b/docs/cli/thv_config_otel_set-env-vars.md @@ -15,16 +15,14 @@ Set the OpenTelemetry environment variables ### Synopsis -Set the list of environment variable names to include in OpenTelemetry spans. - -These environment variables will be used by default when running MCP servers unless overridden by the --otel-env-vars flag. +Set the OpenTelemetry environment variables. Example: - thv config otel set-env-vars USER,HOME,PATH + thv config otel set-env-vars VAR1,VAR2,VAR3 ``` -thv config otel set-env-vars [flags] +thv config otel set-env-vars [flags] ``` ### Options diff --git a/docs/cli/thv_config_otel_set-insecure.md b/docs/cli/thv_config_otel_set-insecure.md index d60d698c7..98168e5ae 100644 --- a/docs/cli/thv_config_otel_set-insecure.md +++ b/docs/cli/thv_config_otel_set-insecure.md @@ -11,16 +11,18 @@ mdx: ## thv config otel set-insecure -Set the OpenTelemetry insecure transport flag +Set the OpenTelemetry insecure connection flag ### Synopsis -Set the OpenTelemetry insecure flag to enable HTTP instead of HTTPS for OTLP endpoints. +Set the OpenTelemetry insecure connection flag. + +Example: thv config otel set-insecure true ``` -thv config otel set-insecure [flags] +thv config otel set-insecure [flags] ``` ### Options diff --git a/docs/cli/thv_config_otel_set-metrics-enabled.md b/docs/cli/thv_config_otel_set-metrics-enabled.md index 7fb0305b3..5a7ad168c 100644 --- a/docs/cli/thv_config_otel_set-metrics-enabled.md +++ b/docs/cli/thv_config_otel_set-metrics-enabled.md @@ -11,16 +11,18 @@ mdx: ## thv config otel set-metrics-enabled -Set the OpenTelemetry metrics export to enabled +Set the OpenTelemetry metrics export flag ### Synopsis -Set the OpenTelemetry metrics flag to enable to export metrics to an OTel collector. +Set the OpenTelemetry metrics export flag. + +Example: thv config otel set-metrics-enabled true ``` -thv config otel set-metrics-enabled [flags] +thv config otel set-metrics-enabled [flags] ``` ### Options diff --git a/docs/cli/thv_config_otel_set-sampling-rate.md b/docs/cli/thv_config_otel_set-sampling-rate.md index f8a734711..72dec6adf 100644 --- a/docs/cli/thv_config_otel_set-sampling-rate.md +++ b/docs/cli/thv_config_otel_set-sampling-rate.md @@ -15,16 +15,14 @@ Set the OpenTelemetry sampling rate ### Synopsis -Set the OpenTelemetry trace sampling rate (between 0.0 and 1.0). - -This sampling rate will be used by default when running MCP servers unless overridden by the --otel-sampling-rate flag. +Set the OpenTelemetry sampling rate. Example: - thv config otel set-sampling-rate 0.1 + thv config otel set-sampling-rate 0.5 ``` -thv config otel set-sampling-rate [flags] +thv config otel set-sampling-rate [flags] ``` ### Options diff --git a/docs/cli/thv_config_otel_set-tracing-enabled.md b/docs/cli/thv_config_otel_set-tracing-enabled.md index fb02fbee3..d02d03d25 100644 --- a/docs/cli/thv_config_otel_set-tracing-enabled.md +++ b/docs/cli/thv_config_otel_set-tracing-enabled.md @@ -11,16 +11,18 @@ mdx: ## thv config otel set-tracing-enabled -Set the OpenTelemetry tracing export to enabled +Set the OpenTelemetry tracing export flag ### Synopsis -Set the OpenTelemetry tracing flag to enable to export traces to an OTel collector. +Set the OpenTelemetry tracing export flag. + +Example: thv config otel set-tracing-enabled true ``` -thv config otel set-tracing-enabled [flags] +thv config otel set-tracing-enabled [flags] ``` ### Options diff --git a/docs/cli/thv_config_otel_unset-insecure.md b/docs/cli/thv_config_otel_unset-insecure.md index 3e31d7372..5675ac303 100644 --- a/docs/cli/thv_config_otel_unset-insecure.md +++ b/docs/cli/thv_config_otel_unset-insecure.md @@ -11,11 +11,11 @@ mdx: ## thv config otel unset-insecure -Remove the configured OpenTelemetry insecure transport flag +Remove the configured OpenTelemetry insecure connection flag ### Synopsis -Remove the OpenTelemetry insecure transport flag configuration. +Remove the OpenTelemetry insecure connection flag configuration. ``` thv config otel unset-insecure [flags] diff --git a/pkg/config/doc.go b/pkg/config/doc.go new file mode 100644 index 000000000..6257b96b0 --- /dev/null +++ b/pkg/config/doc.go @@ -0,0 +1,110 @@ +// Package config provides configuration management for ToolHive, including a +// generic framework for easily adding new configuration fields. +// +// # Architecture +// +// The package uses a Provider pattern to abstract configuration storage: +// - DefaultProvider: Uses XDG config directories (~/.config/toolhive/config.yaml) +// - PathProvider: Uses a specific file path (useful for testing) +// - KubernetesProvider: No-op implementation for Kubernetes environments +// +// # Generic Config Field Framework +// +// The framework allows you to define config fields declaratively with minimal +// boilerplate. Fields are registered once with validation, getters, setters, +// and unseters. +// +// # Adding a New Config Field +// +// Step 1: Add your field to the Config struct: +// +// type Config struct { +// // ... existing fields ... +// MyNewField string `yaml:"my_new_field,omitempty"` +// } +// +// Step 2: Register the field using a helper constructor: +// +// func init() { +// // For simple string fields: +// config.RegisterStringField("my-field", +// func(cfg *Config) *string { return &cfg.MyNewField }, +// validateMyField) // Optional validator +// +// // For boolean fields: +// config.RegisterBoolField("my-bool-field", +// func(cfg *Config) *bool { return &cfg.MyBoolField }, +// nil) // nil = no validation +// +// // For float fields: +// config.RegisterFloatField("my-float-field", +// func(cfg *Config) *float64 { return &cfg.MyFloatField }, +// 0.0, // zero value +// validateMyFloat) +// +// // For string slice fields (comma-separated): +// config.RegisterStringSliceField("my-list-field", +// func(cfg *Config) *[]string { return &cfg.MyListField }, +// nil) +// } +// +// Step 3: Use the field through the generic framework: +// +// provider := config.NewDefaultProvider() +// +// // Set a value +// err := config.SetConfigField(provider, "my-field", "some-value") +// +// // Get a value +// value, isSet, err := config.GetConfigField(provider, "my-field") +// +// // Unset a value +// err := config.UnsetConfigField(provider, "my-field") +// +// # Advanced: Custom Field Registration +// +// For fields with complex logic, use RegisterConfigField directly: +// +// config.RegisterConfigField(config.ConfigFieldSpec{ +// Name: "my-complex-field", +// SetValidator: func(_ Provider, value string) error { +// // Custom validation logic +// return nil +// }, +// Setter: func(cfg *Config, value string) { +// // Custom setter logic +// }, +// Getter: func(cfg *Config) string { +// // Custom getter logic +// return "" +// }, +// Unsetter: func(cfg *Config) { +// // Custom unsetter logic +// }, +// }) +// +// # Validation Helpers +// +// The package provides common validation functions: +// - validateFilePath: Validates file exists and returns cleaned path +// - validateFileExists: Checks if file exists +// - validateJSONFile: Validates file is JSON format +// - validateURLScheme: Validates URL scheme (http/https) +// - makeAbsolutePath: Converts relative to absolute path +// +// Use these in your validator function for consistent error messages. +// +// # Built-in Fields +// +// The following fields are currently registered: +// - ca-cert: Path to a CA certificate file for TLS validation +// - registry-url: URL of the MCP server registry (HTTP/HTTPS) +// - registry-file: Path to a local JSON file containing the registry +// - otel-endpoint: OpenTelemetry OTLP endpoint +// - otel-sampling-rate: Trace sampling rate (0.0-1.0) +// - otel-env-vars: Environment variables for telemetry +// - otel-metrics-enabled: Enable metrics export +// - otel-tracing-enabled: Enable tracing export +// - otel-insecure: Use insecure connection +// - otel-enable-prometheus-metrics-path: Enable Prometheus endpoint +package config diff --git a/pkg/config/fields.go b/pkg/config/fields.go new file mode 100644 index 000000000..b396929e8 --- /dev/null +++ b/pkg/config/fields.go @@ -0,0 +1,261 @@ +package config + +import ( + "fmt" + "strconv" + "strings" + "sync" +) + +// ConfigFieldSpec defines the specification for a generic config field. +// It encapsulates all the logic needed to set, get, unset, and validate a config field. +// +//nolint:revive // ConfigFieldSpec is clear and preferred over FieldSpec +type ConfigFieldSpec struct { + // Name is the unique identifier for the field (e.g., "ca-cert", "registry-url") + Name string + + // SetValidator validates the value before setting it. + // Returns an error if the value is invalid. + // This is called before Setter. + SetValidator func(provider Provider, value string) error + + // Setter sets the value in the Config struct. + // It receives the config to modify and the validated value. + Setter func(cfg *Config, value string) + + // Getter retrieves the current value from the Config struct. + // Returns the current value as a string. + Getter func(cfg *Config) string + + // Unsetter clears the field in the Config struct. + // It resets the field to its default/empty state. + Unsetter func(cfg *Config) +} + +// fieldRegistry stores all registered config field specifications +var fieldRegistry = make(map[string]ConfigFieldSpec) + +// registryMutex protects concurrent access to the field registry +var registryMutex sync.RWMutex + +// RegisterConfigField registers a new config field specification. +// This function is typically called during package initialization. +// Panics if a field with the same name is already registered. +func RegisterConfigField(spec ConfigFieldSpec) { + registryMutex.Lock() + defer registryMutex.Unlock() + + if spec.Name == "" { + panic("config field name cannot be empty") + } + + if _, exists := fieldRegistry[spec.Name]; exists { + panic(fmt.Sprintf("config field %q is already registered", spec.Name)) + } + + // Validate required fields + if spec.Setter == nil { + panic(fmt.Sprintf("config field %q must have a Setter", spec.Name)) + } + if spec.Getter == nil { + panic(fmt.Sprintf("config field %q must have a Getter", spec.Name)) + } + if spec.Unsetter == nil { + panic(fmt.Sprintf("config field %q must have an Unsetter", spec.Name)) + } + + fieldRegistry[spec.Name] = spec +} + +// GetConfigFieldSpec retrieves a registered config field specification by name. +// Returns the field spec and true if found, or an empty spec and false if not found. +func GetConfigFieldSpec(fieldName string) (ConfigFieldSpec, bool) { + registryMutex.RLock() + defer registryMutex.RUnlock() + + spec, exists := fieldRegistry[fieldName] + return spec, exists +} + +// ListConfigFields returns a list of all registered config field names. +func ListConfigFields() []string { + registryMutex.RLock() + defer registryMutex.RUnlock() + + fields := make([]string, 0, len(fieldRegistry)) + for name := range fieldRegistry { + fields = append(fields, name) + } + return fields +} + +// SetConfigField sets a config field value using the generic framework. +// It looks up the field spec, validates the value, and updates the config. +// Returns an error if the field is not registered, validation fails, or update fails. +func SetConfigField(provider Provider, fieldName, value string) error { + spec, exists := GetConfigFieldSpec(fieldName) + if !exists { + return fmt.Errorf("unknown config field: %q", fieldName) + } + + // Run custom validation if provided + if spec.SetValidator != nil { + if err := spec.SetValidator(provider, value); err != nil { + return err + } + } + + // Update the config + err := provider.UpdateConfig(func(cfg *Config) { + spec.Setter(cfg, value) + }) + if err != nil { + return fmt.Errorf("failed to update configuration: %w", err) + } + + return nil +} + +// GetConfigField retrieves a config field value using the generic framework. +// It looks up the field spec and returns the current value. +// Returns the value, whether it's set (non-empty), and any error. +func GetConfigField(provider Provider, fieldName string) (value string, isSet bool, err error) { + spec, exists := GetConfigFieldSpec(fieldName) + if !exists { + return "", false, fmt.Errorf("unknown config field: %q", fieldName) + } + + cfg := provider.GetConfig() + value = spec.Getter(cfg) + isSet = value != "" + + return value, isSet, nil +} + +// UnsetConfigField clears a config field using the generic framework. +// It looks up the field spec and resets the field to its default state. +// Returns an error if the field is not registered or update fails. +func UnsetConfigField(provider Provider, fieldName string) error { + spec, exists := GetConfigFieldSpec(fieldName) + if !exists { + return fmt.Errorf("unknown config field: %q", fieldName) + } + + // Update the config + err := provider.UpdateConfig(func(cfg *Config) { + spec.Unsetter(cfg) + }) + if err != nil { + return fmt.Errorf("failed to update configuration: %w", err) + } + + return nil +} + +// Helper constructors for common field types + +// RegisterStringField registers a simple string config field with optional validation. +// The fieldGetter returns a pointer to the string field in the config struct. +func RegisterStringField( + name string, + fieldGetter func(*Config) *string, + validator func(Provider, string) error, +) { + RegisterConfigField(ConfigFieldSpec{ + Name: name, + SetValidator: validator, + Setter: func(cfg *Config, value string) { + *fieldGetter(cfg) = value + }, + Getter: func(cfg *Config) string { + return *fieldGetter(cfg) + }, + Unsetter: func(cfg *Config) { + *fieldGetter(cfg) = "" + }, + }) +} + +// RegisterBoolField registers a boolean config field with automatic string conversion. +// The fieldGetter returns a pointer to the bool field in the config struct. +func RegisterBoolField( + name string, + fieldGetter func(*Config) *bool, + validator func(Provider, string) error, +) { + RegisterConfigField(ConfigFieldSpec{ + Name: name, + SetValidator: validator, + Setter: func(cfg *Config, value string) { + enabled, _ := strconv.ParseBool(value) // Already validated + *fieldGetter(cfg) = enabled + }, + Getter: func(cfg *Config) string { + return strconv.FormatBool(*fieldGetter(cfg)) + }, + Unsetter: func(cfg *Config) { + *fieldGetter(cfg) = false + }, + }) +} + +// RegisterFloatField registers a float64 config field with automatic string conversion. +// The fieldGetter returns a pointer to the float64 field in the config struct. +// The zeroValue parameter specifies what value indicates "unset" (typically 0.0). +func RegisterFloatField( + name string, + fieldGetter func(*Config) *float64, + zeroValue float64, + validator func(Provider, string) error, +) { + RegisterConfigField(ConfigFieldSpec{ + Name: name, + SetValidator: validator, + Setter: func(cfg *Config, value string) { + floatVal, _ := strconv.ParseFloat(value, 64) // Already validated + *fieldGetter(cfg) = floatVal + }, + Getter: func(cfg *Config) string { + val := *fieldGetter(cfg) + if val == zeroValue { + return "" + } + return strconv.FormatFloat(val, 'f', -1, 64) + }, + Unsetter: func(cfg *Config) { + *fieldGetter(cfg) = zeroValue + }, + }) +} + +// RegisterStringSliceField registers a string slice config field with comma-separated string conversion. +// The fieldGetter returns a pointer to the []string field in the config struct. +func RegisterStringSliceField( + name string, + fieldGetter func(*Config) *[]string, + validator func(Provider, string) error, +) { + RegisterConfigField(ConfigFieldSpec{ + Name: name, + SetValidator: validator, + Setter: func(cfg *Config, value string) { + vars := strings.Split(value, ",") + // Trim whitespace from each item + for i, item := range vars { + vars[i] = strings.TrimSpace(item) + } + *fieldGetter(cfg) = vars + }, + Getter: func(cfg *Config) string { + slice := *fieldGetter(cfg) + if len(slice) == 0 { + return "" + } + return strings.Join(slice, ",") + }, + Unsetter: func(cfg *Config) { + *fieldGetter(cfg) = nil + }, + }) +} diff --git a/pkg/config/fields_builtin.go b/pkg/config/fields_builtin.go new file mode 100644 index 000000000..6e6934d0a --- /dev/null +++ b/pkg/config/fields_builtin.go @@ -0,0 +1,142 @@ +package config + +import ( + "fmt" + "strings" + + "github.com/stacklok/toolhive/pkg/certs" + "github.com/stacklok/toolhive/pkg/networking" +) + +// init registers all built-in config fields +func init() { + registerCACertField() + registerRegistryURLField() + registerRegistryFileField() +} + +// registerCACertField registers the CA certificate config field +func registerCACertField() { + RegisterStringField("ca-cert", + func(cfg *Config) *string { return &cfg.CACertificatePath }, + func(_ Provider, value string) error { + // Validate and clean the file path + cleanPath, err := validateFilePath(value) + if err != nil { + return fmt.Errorf("CA certificate %w", err) + } + + // Read the certificate + certContent, err := readFile(cleanPath) + if err != nil { + return fmt.Errorf("CA certificate %w", err) + } + + // Validate the certificate format + if err := certs.ValidateCACertificate(certContent); err != nil { + return fmt.Errorf("invalid CA certificate: %w", err) + } + + return nil + }) +} + +// registerRegistryURLField registers the registry URL config field +func registerRegistryURLField() { + RegisterConfigField(ConfigFieldSpec{ + Name: "registry-url", + SetValidator: func(provider Provider, value string) error { + // Parse the URL to extract the allowInsecure flag + // Format: "url" or "url|insecure" for backward compatibility + parts := strings.Split(value, "|") + registryURL := parts[0] + allowInsecure := len(parts) > 1 && parts[1] == "insecure" + + // Validate URL scheme + _, err := validateURLScheme(registryURL, allowInsecure) + if err != nil { + return fmt.Errorf("invalid registry URL: %w", err) + } + + // Check for private IP addresses if not allowed + cfg := provider.GetConfig() + if !cfg.AllowPrivateRegistryIp && !allowInsecure { + registryClient, err := networking.NewHttpClientBuilder().Build() + if err != nil { + return fmt.Errorf("failed to create HTTP client: %w", err) + } + _, err = registryClient.Get(registryURL) + if err != nil && strings.Contains(fmt.Sprint(err), networking.ErrPrivateIpAddress) { + return err + } + } + + return nil + }, + Setter: func(cfg *Config, value string) { + // Parse the value to extract URL and allowInsecure flag + parts := strings.Split(value, "|") + registryURL := parts[0] + allowInsecure := len(parts) > 1 && parts[1] == "insecure" + + cfg.RegistryUrl = registryURL + cfg.LocalRegistryPath = "" // Clear local path when setting URL + if allowInsecure { + cfg.AllowPrivateRegistryIp = true + } + }, + Getter: func(cfg *Config) string { + if cfg.RegistryUrl == "" { + return "" + } + // Return URL with insecure flag if set + if cfg.AllowPrivateRegistryIp { + return cfg.RegistryUrl + "|insecure" + } + return cfg.RegistryUrl + }, + Unsetter: func(cfg *Config) { + cfg.RegistryUrl = "" + cfg.AllowPrivateRegistryIp = false + }, + }) +} + +// registerRegistryFileField registers the registry file config field +func registerRegistryFileField() { + RegisterConfigField(ConfigFieldSpec{ + Name: "registry-file", + SetValidator: func(_ Provider, value string) error { + // Validate file path exists + cleanPath, err := validateFilePath(value) + if err != nil { + return fmt.Errorf("local registry %w", err) + } + + // Validate JSON file + if err := validateJSONFile(cleanPath); err != nil { + return fmt.Errorf("registry file: %w", err) + } + + return nil + }, + Setter: func(cfg *Config, value string) { + // Clean and make absolute + cleanPath, _ := validateFilePath(value) + absPath, err := makeAbsolutePath(cleanPath) + if err != nil { + // Fall back to cleaned path if absolute path resolution fails + absPath = cleanPath + } + + cfg.LocalRegistryPath = absPath + cfg.RegistryUrl = "" // Clear URL when setting local path + }, + Getter: func(cfg *Config) string { + return cfg.LocalRegistryPath + }, + Unsetter: func(cfg *Config) { + cfg.LocalRegistryPath = "" + }, + }) +} diff --git a/pkg/config/fields_builtin_test.go b/pkg/config/fields_builtin_test.go new file mode 100644 index 000000000..00e707a3c --- /dev/null +++ b/pkg/config/fields_builtin_test.go @@ -0,0 +1,378 @@ +package config + +import ( + "os" + "path/filepath" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/stacklok/toolhive/pkg/logger" +) + +const ( + validCertType = "valid" + invalidCertType = "invalid" + nonexistentCertType = "nonexistent" + validFileType = "valid" + invalidJSONFileType = "invalid-json" + nonJSONFileType = "non-json" +) + +func TestBuiltinFields_CACert(t *testing.T) { + t.Parallel() + logger.Initialize() + + tests := []struct { + name string + certType string // validCertType, invalidCertType, nonexistentCertType + wantErr bool + errContains string + }{ + { + name: "valid CA certificate", + certType: validCertType, + wantErr: false, + }, + { + name: "non-existent certificate", + certType: nonexistentCertType, + wantErr: true, + errContains: "file not found", + }, + { + name: "invalid certificate format", + certType: invalidCertType, + wantErr: true, + errContains: "invalid CA certificate", + }, + } + + for _, tt := range tests { + tt := tt // capture range variable + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + // Create test files for each subtest + tempDir := t.TempDir() + configPath := filepath.Join(tempDir, "config.yaml") + + // Create a valid CA certificate for testing + certPath := filepath.Join(tempDir, "test-cert.pem") + err := os.WriteFile(certPath, []byte(validCACertificate), 0600) + require.NoError(t, err) + + // Create an invalid certificate + invalidCertPath := filepath.Join(tempDir, "invalid-cert.pem") + err = os.WriteFile(invalidCertPath, []byte("not a valid certificate"), 0600) + require.NoError(t, err) + + provider := NewPathProvider(configPath) + + // Determine which cert path to use based on test type + var testCertPath string + switch tt.certType { + case validCertType: + testCertPath = certPath + case invalidCertType: + testCertPath = invalidCertPath + case nonexistentCertType: + testCertPath = "/non/existent/cert.pem" + } + + err = SetConfigField(provider, "ca-cert", testCertPath) + + if tt.wantErr { + require.Error(t, err) + assert.Contains(t, err.Error(), tt.errContains) + } else { + require.NoError(t, err) + + // Verify the field was set correctly + value, isSet, err := GetConfigField(provider, "ca-cert") + require.NoError(t, err) + assert.True(t, isSet) + assert.Equal(t, testCertPath, value) + + // Test unset + err = UnsetConfigField(provider, "ca-cert") + require.NoError(t, err) + + value, isSet, err = GetConfigField(provider, "ca-cert") + require.NoError(t, err) + assert.False(t, isSet) + assert.Empty(t, value) + } + }) + } +} + +func TestBuiltinFields_RegistryFile(t *testing.T) { + t.Parallel() + logger.Initialize() + + tests := []struct { + name string + fileType string // validFileType, invalidJSONFileType, nonJSONFileType, nonexistentCertType + wantErr bool + errContains string + }{ + { + name: "valid registry file", + fileType: validFileType, + wantErr: false, + }, + { + name: "non-existent file", + fileType: nonexistentCertType, + wantErr: true, + errContains: "file not found", + }, + { + name: "invalid JSON content", + fileType: invalidJSONFileType, + wantErr: true, + errContains: "invalid JSON format", + }, + { + name: "non-JSON file extension", + fileType: nonJSONFileType, + wantErr: true, + errContains: "must be a JSON file", + }, + } + + for _, tt := range tests { + tt := tt // capture range variable + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + // Create test files for each subtest + tempDir := t.TempDir() + configPath := filepath.Join(tempDir, "config.yaml") + + // Create a valid JSON registry file + validRegistryPath := filepath.Join(tempDir, "registry.json") + validJSON := `{"servers": [{"name": "test", "url": "https://example.com"}]}` + err := os.WriteFile(validRegistryPath, []byte(validJSON), 0600) + require.NoError(t, err) + + // Create an invalid JSON file + invalidJSONPath := filepath.Join(tempDir, "invalid.json") + err = os.WriteFile(invalidJSONPath, []byte("not valid json"), 0600) + require.NoError(t, err) + + // Create a non-JSON file + nonJSONPath := filepath.Join(tempDir, "registry.txt") + err = os.WriteFile(nonJSONPath, []byte("some text"), 0600) + require.NoError(t, err) + + provider := NewPathProvider(configPath) + + // Determine which registry path to use based on file type + var testRegistryPath string + switch tt.fileType { + case validFileType: + testRegistryPath = validRegistryPath + case invalidJSONFileType: + testRegistryPath = invalidJSONPath + case nonJSONFileType: + testRegistryPath = nonJSONPath + case nonexistentCertType: + testRegistryPath = "/non/existent/registry.json" + } + + err = SetConfigField(provider, "registry-file", testRegistryPath) + + if tt.wantErr { + require.Error(t, err) + assert.Contains(t, err.Error(), tt.errContains) + } else { + require.NoError(t, err) + + // Verify the field was set correctly (should be absolute path) + value, isSet, err := GetConfigField(provider, "registry-file") + require.NoError(t, err) + assert.True(t, isSet) + assert.True(t, filepath.IsAbs(value), "path should be absolute") + + // The value might be cleaned/absolute, so check the base name + assert.Equal(t, filepath.Base(testRegistryPath), filepath.Base(value)) + + // Verify URL is cleared when file is set + cfg := provider.GetConfig() + assert.Empty(t, cfg.RegistryUrl, "registry URL should be cleared when file is set") + + // Test unset + err = UnsetConfigField(provider, "registry-file") + require.NoError(t, err) + + value, isSet, err = GetConfigField(provider, "registry-file") + require.NoError(t, err) + assert.False(t, isSet) + assert.Empty(t, value) + } + }) + } +} + +func TestBuiltinFields_RegistryURL(t *testing.T) { + t.Parallel() + logger.Initialize() + + tests := []struct { + name string + registryURL string + wantErr bool + errContains string + expectInsecure bool + expectedStored string + expectPrivateSet bool + }{ + { + name: "valid HTTPS URL", + registryURL: "https://registry.example.com/servers", + wantErr: false, + expectedStored: "https://registry.example.com/servers", + }, + { + name: "HTTP URL with insecure flag", + registryURL: "http://registry.example.com/servers|insecure", + wantErr: false, + expectInsecure: true, + expectedStored: "http://registry.example.com/servers|insecure", + expectPrivateSet: true, + }, + { + name: "invalid URL format", + registryURL: "not-a-url", + wantErr: true, + errContains: "invalid registry URL", + }, + { + name: "HTTP without insecure flag", + registryURL: "http://registry.example.com/servers", + wantErr: true, + errContains: "invalid registry URL", + }, + } + + for _, tt := range tests { + tt := tt // capture range variable + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + // Create a separate config file for each test case to avoid shared state + tempDir := t.TempDir() + configPath := filepath.Join(tempDir, "config.yaml") + provider := NewPathProvider(configPath) + + err := SetConfigField(provider, "registry-url", tt.registryURL) + + if tt.wantErr { + require.Error(t, err) + assert.Contains(t, err.Error(), tt.errContains) + } else { + require.NoError(t, err) + + // Verify the field was set correctly + value, isSet, err := GetConfigField(provider, "registry-url") + require.NoError(t, err) + assert.True(t, isSet) + assert.Equal(t, tt.expectedStored, value) + + // Verify file path is cleared when URL is set + cfg := provider.GetConfig() + assert.Empty(t, cfg.LocalRegistryPath, "registry file should be cleared when URL is set") + + // Verify AllowPrivateRegistryIp is set correctly + if tt.expectPrivateSet { + assert.True(t, cfg.AllowPrivateRegistryIp, "AllowPrivateRegistryIp should be set for insecure URLs") + } + + // Test unset + err = UnsetConfigField(provider, "registry-url") + require.NoError(t, err) + + value, isSet, err = GetConfigField(provider, "registry-url") + require.NoError(t, err) + assert.False(t, isSet) + assert.Empty(t, value) + + // Verify AllowPrivateRegistryIp is reset + cfg = provider.GetConfig() + assert.False(t, cfg.AllowPrivateRegistryIp, "AllowPrivateRegistryIp should be reset after unset") + } + }) + } +} + +func TestBuiltinFields_MutualExclusivity(t *testing.T) { + t.Parallel() + logger.Initialize() + + // Create a temporary config file + tempDir := t.TempDir() + configPath := filepath.Join(tempDir, "config.yaml") + + // Create a valid JSON registry file + validRegistryPath := filepath.Join(tempDir, "registry.json") + validJSON := `{"servers": []}` + err := os.WriteFile(validRegistryPath, []byte(validJSON), 0600) + require.NoError(t, err) + + provider := NewPathProvider(configPath) + + // Set registry URL + err = SetConfigField(provider, "registry-url", "https://registry.example.com") + require.NoError(t, err) + + // Verify URL is set + cfg := provider.GetConfig() + assert.NotEmpty(t, cfg.RegistryUrl) + assert.Empty(t, cfg.LocalRegistryPath) + + // Set registry file (should clear URL) + err = SetConfigField(provider, "registry-file", validRegistryPath) + require.NoError(t, err) + + // Verify file is set and URL is cleared + cfg = provider.GetConfig() + assert.Empty(t, cfg.RegistryUrl) + assert.NotEmpty(t, cfg.LocalRegistryPath) + + // Set registry URL again (should clear file) + err = SetConfigField(provider, "registry-url", "https://registry2.example.com") + require.NoError(t, err) + + // Verify URL is set and file is cleared + cfg = provider.GetConfig() + assert.NotEmpty(t, cfg.RegistryUrl) + assert.Empty(t, cfg.LocalRegistryPath) +} + +func TestBuiltinFields_AllFieldsRegistered(t *testing.T) { + t.Parallel() + + expectedFields := []string{ + "ca-cert", + "registry-url", + "registry-file", + } + + registeredFields := ListConfigFields() + fieldMap := make(map[string]bool) + for _, field := range registeredFields { + fieldMap[field] = true + } + + for _, expectedField := range expectedFields { + assert.True(t, fieldMap[expectedField], "field %q should be registered", expectedField) + + // Verify the field has all required components + spec, exists := GetConfigFieldSpec(expectedField) + require.True(t, exists, "field %q should be registered", expectedField) + assert.NotEmpty(t, spec.Name, "field %q should have a name", expectedField) + assert.NotNil(t, spec.Setter, "field %q should have a setter", expectedField) + assert.NotNil(t, spec.Getter, "field %q should have a getter", expectedField) + assert.NotNil(t, spec.Unsetter, "field %q should have an unsetter", expectedField) + } +} diff --git a/pkg/config/fields_otel.go b/pkg/config/fields_otel.go new file mode 100644 index 000000000..fc0aafa23 --- /dev/null +++ b/pkg/config/fields_otel.go @@ -0,0 +1,92 @@ +package config + +import ( + "fmt" + "strconv" + "strings" +) + +// init registers all OTEL config fields +func init() { + registerOTELEndpoint() + registerOTELSamplingRate() + registerOTELEnvVars() + registerOTELMetricsEnabled() + registerOTELTracingEnabled() + registerOTELInsecure() + registerOTELEnablePrometheusMetricsPath() +} + +// Validators for OTEL fields + +func validateOTELEndpoint(_ Provider, value string) error { + // The endpoint should not start with http:// or https:// + if value != "" && (strings.HasPrefix(value, "http://") || strings.HasPrefix(value, "https://")) { + return fmt.Errorf("endpoint URL should not start with http:// or https://") + } + return nil +} + +func validateOTELSamplingRate(_ Provider, value string) error { + rate, err := strconv.ParseFloat(value, 64) + if err != nil { + return fmt.Errorf("invalid sampling rate format: %w", err) + } + if rate < 0.0 || rate > 1.0 { + return fmt.Errorf("sampling rate must be between 0.0 and 1.0") + } + return nil +} + +func validateBool(_ Provider, value string) error { + _, err := strconv.ParseBool(value) + if err != nil { + return fmt.Errorf("invalid boolean value: %w (expected: true, false, 1, 0)", err) + } + return nil +} + +// Field registrations using helper constructors + +func registerOTELEndpoint() { + RegisterStringField("otel-endpoint", + func(cfg *Config) *string { return &cfg.OTEL.Endpoint }, + validateOTELEndpoint) +} + +func registerOTELSamplingRate() { + RegisterFloatField("otel-sampling-rate", + func(cfg *Config) *float64 { return &cfg.OTEL.SamplingRate }, + 0.0, + validateOTELSamplingRate) +} + +func registerOTELEnvVars() { + RegisterStringSliceField("otel-env-vars", + func(cfg *Config) *[]string { return &cfg.OTEL.EnvVars }, + nil) // No validation needed +} + +func registerOTELMetricsEnabled() { + RegisterBoolField("otel-metrics-enabled", + func(cfg *Config) *bool { return &cfg.OTEL.MetricsEnabled }, + validateBool) +} + +func registerOTELTracingEnabled() { + RegisterBoolField("otel-tracing-enabled", + func(cfg *Config) *bool { return &cfg.OTEL.TracingEnabled }, + validateBool) +} + +func registerOTELInsecure() { + RegisterBoolField("otel-insecure", + func(cfg *Config) *bool { return &cfg.OTEL.Insecure }, + validateBool) +} + +func registerOTELEnablePrometheusMetricsPath() { + RegisterBoolField("otel-enable-prometheus-metrics-path", + func(cfg *Config) *bool { return &cfg.OTEL.EnablePrometheusMetricsPath }, + validateBool) +} diff --git a/pkg/config/fields_test.go b/pkg/config/fields_test.go new file mode 100644 index 000000000..04e553eb9 --- /dev/null +++ b/pkg/config/fields_test.go @@ -0,0 +1,582 @@ +package config + +import ( + "fmt" + "os" + "path/filepath" + "sync" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/stacklok/toolhive/pkg/logger" +) + +// testProvider is a simple test implementation of Provider +type testProvider struct { + config *Config + updateError error +} + +func (p *testProvider) GetConfig() *Config { + if p.config == nil { + cfg := createNewConfigWithDefaults() + p.config = &cfg + } + return p.config +} + +func (p *testProvider) UpdateConfig(updateFn func(*Config)) error { + if p.updateError != nil { + return p.updateError + } + cfg := p.GetConfig() + updateFn(cfg) + return nil +} + +func (p *testProvider) LoadOrCreateConfig() (*Config, error) { + return p.GetConfig(), nil +} + +func (p *testProvider) SetRegistryURL(registryURL string, allowPrivateRegistryIp bool) error { + return setRegistryURL(p, registryURL, allowPrivateRegistryIp) +} + +func (p *testProvider) SetRegistryFile(registryPath string) error { + return setRegistryFile(p, registryPath) +} + +func (p *testProvider) UnsetRegistry() error { + return unsetRegistry(p) +} + +func (p *testProvider) GetRegistryConfig() (url, localPath string, allowPrivateIP bool, registryType string) { + return getRegistryConfig(p) +} + +func (p *testProvider) SetCACert(certPath string) error { + return setCACert(p, certPath) +} + +func (p *testProvider) GetCACert() (certPath string, exists bool, accessible bool) { + return getCACert(p) +} + +func (p *testProvider) UnsetCACert() error { + return unsetCACert(p) +} + +func TestRegisterConfigField(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + spec ConfigFieldSpec + shouldPanic bool + panicMsg string + }{ + { + name: "successful registration", + spec: ConfigFieldSpec{ + Name: "test-field-" + t.Name(), + Setter: func(cfg *Config, value string) { + cfg.CACertificatePath = value + }, + Getter: func(cfg *Config) string { + return cfg.CACertificatePath + }, + Unsetter: func(cfg *Config) { + cfg.CACertificatePath = "" + }, + }, + shouldPanic: false, + }, + { + name: "empty field name", + spec: ConfigFieldSpec{ + Name: "", + Setter: func(cfg *Config, value string) { + cfg.CACertificatePath = value + }, + Getter: func(cfg *Config) string { + return cfg.CACertificatePath + }, + Unsetter: func(cfg *Config) { + cfg.CACertificatePath = "" + }, + }, + shouldPanic: true, + panicMsg: "config field name cannot be empty", + }, + { + name: "missing setter", + spec: ConfigFieldSpec{ + Name: "test-field-no-setter", + Setter: nil, + Getter: func(cfg *Config) string { + return cfg.CACertificatePath + }, + Unsetter: func(cfg *Config) { + cfg.CACertificatePath = "" + }, + }, + shouldPanic: true, + panicMsg: "must have a Setter", + }, + { + name: "missing getter", + spec: ConfigFieldSpec{ + Name: "test-field-no-getter", + Setter: func(cfg *Config, value string) { + cfg.CACertificatePath = value + }, + Getter: nil, + Unsetter: func(cfg *Config) { + cfg.CACertificatePath = "" + }, + }, + shouldPanic: true, + panicMsg: "must have a Getter", + }, + { + name: "missing unsetter", + spec: ConfigFieldSpec{ + Name: "test-field-no-unsetter", + Setter: func(cfg *Config, value string) { + cfg.CACertificatePath = value + }, + Getter: func(cfg *Config) string { + return cfg.CACertificatePath + }, + Unsetter: nil, + }, + shouldPanic: true, + panicMsg: "must have an Unsetter", + }, + } + + for _, tt := range tests { + tt := tt // capture range variable + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + if tt.shouldPanic { + // Capture the panic and check the message + defer func() { + if r := recover(); r != nil { + panicMsg := fmt.Sprint(r) + assert.Contains(t, panicMsg, tt.panicMsg, "panic message should contain expected substring") + } else { + t.Error("expected panic but none occurred") + } + }() + RegisterConfigField(tt.spec) + } else { + assert.NotPanics(t, func() { + RegisterConfigField(tt.spec) + }, "should not panic") + + // Verify field was registered + spec, exists := GetConfigFieldSpec(tt.spec.Name) + assert.True(t, exists, "field should be registered") + assert.Equal(t, tt.spec.Name, spec.Name, "field name should match") + } + }) + } +} + +func TestGetConfigFieldSpec(t *testing.T) { + t.Parallel() + + // Register a test field + testFieldName := "test-get-field-" + t.Name() + RegisterConfigField(ConfigFieldSpec{ + Name: testFieldName, + Setter: func(cfg *Config, value string) { + cfg.CACertificatePath = value + }, + Getter: func(cfg *Config) string { + return cfg.CACertificatePath + }, + Unsetter: func(cfg *Config) { + cfg.CACertificatePath = "" + }, + }) + + tests := []struct { + name string + fieldName string + wantExists bool + }{ + { + name: "existing field", + fieldName: testFieldName, + wantExists: true, + }, + { + name: "non-existent field", + fieldName: "non-existent-field", + wantExists: false, + }, + } + + for _, tt := range tests { + tt := tt // capture range variable + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + spec, exists := GetConfigFieldSpec(tt.fieldName) + assert.Equal(t, tt.wantExists, exists, "existence check should match") + if tt.wantExists { + assert.Equal(t, tt.fieldName, spec.Name, "field name should match") + assert.NotNil(t, spec.Setter, "setter should not be nil") + assert.NotNil(t, spec.Getter, "getter should not be nil") + assert.NotNil(t, spec.Unsetter, "unsetter should not be nil") + } + }) + } +} + +func TestListConfigFields(t *testing.T) { + t.Parallel() + + // Note: This test checks that built-in fields are registered + // The actual list may vary based on init() functions + fields := ListConfigFields() + assert.NotEmpty(t, fields, "should have registered fields") + + // Check for built-in fields + fieldMap := make(map[string]bool) + for _, field := range fields { + fieldMap[field] = true + } + + assert.True(t, fieldMap["ca-cert"], "should have ca-cert field") + assert.True(t, fieldMap["registry-url"], "should have registry-url field") + assert.True(t, fieldMap["registry-file"], "should have registry-file field") +} + +func TestSetConfigField(t *testing.T) { + t.Parallel() + logger.Initialize() + + tests := []struct { + name string + fieldName string + value string + setupProvider func() Provider + wantErr bool + errContains string + }{ + { + name: "non-existent field", + fieldName: "non-existent-field", + value: "test-value", + setupProvider: func() Provider { + return &testProvider{} + }, + wantErr: true, + errContains: "unknown config field", + }, + { + name: "validation failure", + fieldName: "ca-cert", + value: "/non/existent/cert.pem", + setupProvider: func() Provider { + return &testProvider{} + }, + wantErr: true, + errContains: "file not found", + }, + { + name: "update config failure", + fieldName: "test-update-fail-" + t.Name(), + value: "test-value", + setupProvider: func() Provider { + // Register field before returning provider + RegisterConfigField(ConfigFieldSpec{ + Name: "test-update-fail-" + t.Name(), + Setter: func(cfg *Config, value string) { + cfg.CACertificatePath = value + }, + Getter: func(cfg *Config) string { + return cfg.CACertificatePath + }, + Unsetter: func(cfg *Config) { + cfg.CACertificatePath = "" + }, + }) + return &testProvider{updateError: fmt.Errorf("update failed")} + }, + wantErr: true, + errContains: "failed to update configuration", + }, + } + + for _, tt := range tests { + tt := tt // capture range variable + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + provider := tt.setupProvider() + + err := SetConfigField(provider, tt.fieldName, tt.value) + + if tt.wantErr { + require.Error(t, err) + assert.Contains(t, err.Error(), tt.errContains) + } else { + require.NoError(t, err) + } + }) + } +} + +func TestGetConfigField(t *testing.T) { + t.Parallel() + logger.Initialize() + + tests := []struct { + name string + fieldName string + setupProvider func() Provider + wantValue string + wantIsSet bool + wantErr bool + errContains string + }{ + { + name: "non-existent field", + fieldName: "non-existent-field", + setupProvider: func() Provider { + return &testProvider{} + }, + wantErr: true, + errContains: "unknown config field", + }, + { + name: "field not set", + fieldName: "ca-cert", + setupProvider: func() Provider { + return &testProvider{ + config: &Config{ + CACertificatePath: "", + }, + } + }, + wantValue: "", + wantIsSet: false, + wantErr: false, + }, + { + name: "field is set", + fieldName: "ca-cert", + setupProvider: func() Provider { + return &testProvider{ + config: &Config{ + CACertificatePath: "/path/to/cert.pem", + }, + } + }, + wantValue: "/path/to/cert.pem", + wantIsSet: true, + wantErr: false, + }, + } + + for _, tt := range tests { + tt := tt // capture range variable + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + provider := tt.setupProvider() + + value, isSet, err := GetConfigField(provider, tt.fieldName) + + if tt.wantErr { + require.Error(t, err) + assert.Contains(t, err.Error(), tt.errContains) + } else { + require.NoError(t, err) + assert.Equal(t, tt.wantValue, value) + assert.Equal(t, tt.wantIsSet, isSet) + } + }) + } +} + +func TestUnsetConfigField(t *testing.T) { + t.Parallel() + logger.Initialize() + + tests := []struct { + name string + fieldName string + setupProvider func() Provider + wantErr bool + errContains string + }{ + { + name: "non-existent field", + fieldName: "non-existent-field", + setupProvider: func() Provider { + return &testProvider{} + }, + wantErr: true, + errContains: "unknown config field", + }, + { + name: "successful unset", + fieldName: "test-unset-success-" + t.Name(), + setupProvider: func() Provider { + // Register field before returning provider + RegisterConfigField(ConfigFieldSpec{ + Name: "test-unset-success-" + t.Name(), + Setter: func(cfg *Config, value string) { + cfg.CACertificatePath = value + }, + Getter: func(cfg *Config) string { + return cfg.CACertificatePath + }, + Unsetter: func(cfg *Config) { + cfg.CACertificatePath = "" + }, + }) + return &testProvider{ + config: &Config{CACertificatePath: "/old/path"}, + } + }, + wantErr: false, + }, + { + name: "update config failure", + fieldName: "test-unset-fail-" + t.Name(), + setupProvider: func() Provider { + // Register field before returning provider + RegisterConfigField(ConfigFieldSpec{ + Name: "test-unset-fail-" + t.Name(), + Setter: func(cfg *Config, value string) { + cfg.CACertificatePath = value + }, + Getter: func(cfg *Config) string { + return cfg.CACertificatePath + }, + Unsetter: func(cfg *Config) { + cfg.CACertificatePath = "" + }, + }) + return &testProvider{updateError: fmt.Errorf("update failed")} + }, + wantErr: true, + errContains: "failed to update configuration", + }, + } + + for _, tt := range tests { + tt := tt // capture range variable + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + provider := tt.setupProvider() + + err := UnsetConfigField(provider, tt.fieldName) + + if tt.wantErr { + require.Error(t, err) + assert.Contains(t, err.Error(), tt.errContains) + } else { + require.NoError(t, err) + // Verify the field was actually unset + if tt.name == "successful unset" { + assert.Equal(t, "", provider.GetConfig().CACertificatePath) + } + } + }) + } +} + +func TestFieldRegistryConcurrency(t *testing.T) { + t.Parallel() + + // Test concurrent reads and writes to the field registry + var wg sync.WaitGroup + numGoroutines := 10 + + // Concurrent registration + for i := 0; i < numGoroutines; i++ { + wg.Add(1) + go func(idx int) { + defer wg.Done() + fieldName := fmt.Sprintf("concurrent-field-%d", idx) + RegisterConfigField(ConfigFieldSpec{ + Name: fieldName, + Setter: func(cfg *Config, value string) { + cfg.CACertificatePath = value + }, + Getter: func(cfg *Config) string { + return cfg.CACertificatePath + }, + Unsetter: func(cfg *Config) { + cfg.CACertificatePath = "" + }, + }) + }(i) + } + + // Concurrent reads + for i := 0; i < numGoroutines; i++ { + wg.Add(1) + go func() { + defer wg.Done() + _ = ListConfigFields() + }() + } + + wg.Wait() + + // Verify all fields were registered + fields := ListConfigFields() + fieldMap := make(map[string]bool) + for _, field := range fields { + fieldMap[field] = true + } + + for i := 0; i < numGoroutines; i++ { + fieldName := fmt.Sprintf("concurrent-field-%d", i) + assert.True(t, fieldMap[fieldName], "field %s should be registered", fieldName) + } +} + +func TestSetConfigFieldIntegration(t *testing.T) { + t.Parallel() + logger.Initialize() + + // Create a temporary config file + tempDir := t.TempDir() + configPath := filepath.Join(tempDir, "config.yaml") + + // Create a valid CA certificate for testing + certPath := filepath.Join(tempDir, "test-cert.pem") + err := os.WriteFile(certPath, []byte(validCACertificate), 0600) + require.NoError(t, err) + + // Create a provider + provider := NewPathProvider(configPath) + + // Test setting CA certificate + err = SetConfigField(provider, "ca-cert", certPath) + require.NoError(t, err) + + // Verify the field was set + value, isSet, err := GetConfigField(provider, "ca-cert") + require.NoError(t, err) + assert.True(t, isSet) + assert.Equal(t, certPath, value) + + // Test unsetting CA certificate + err = UnsetConfigField(provider, "ca-cert") + require.NoError(t, err) + + // Verify the field was unset + value, isSet, err = GetConfigField(provider, "ca-cert") + require.NoError(t, err) + assert.False(t, isSet) + assert.Empty(t, value) +}