diff --git a/provider/decode_test.go b/provider/decode_test.go index c7215bd8..ca3feccd 100644 --- a/provider/decode_test.go +++ b/provider/decode_test.go @@ -5,6 +5,7 @@ import ( "github.com/coder/terraform-provider-coder/provider" "github.com/mitchellh/mapstructure" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -18,15 +19,25 @@ func TestDecode(t *testing.T) { aMap := map[string]interface{}{ "name": "Parameter Name", + "type": "number", "display_name": displayName, "legacy_variable": legacyVariable, "legacy_variable_name": legacyVariableName, + "min": nil, + "validation": []map[string]interface{}{ + { + "min": nil, + "max": 5, + }, + }, } var param provider.Parameter err := mapstructure.Decode(aMap, ¶m) require.NoError(t, err) - require.Equal(t, displayName, param.DisplayName) - require.Equal(t, legacyVariable, param.LegacyVariable) - require.Equal(t, legacyVariableName, param.LegacyVariableName) + assert.Equal(t, displayName, param.DisplayName) + assert.Equal(t, legacyVariable, param.LegacyVariable) + assert.Equal(t, legacyVariableName, param.LegacyVariableName) + assert.Equal(t, (*int)(nil), param.Validation[0].Min) + assert.Equal(t, 5, *param.Validation[0].Max) } diff --git a/provider/parameter.go b/provider/parameter.go index 972950a6..a9bbdc65 100644 --- a/provider/parameter.go +++ b/provider/parameter.go @@ -12,10 +12,12 @@ import ( "strconv" "github.com/google/uuid" + "github.com/hashicorp/go-cty/cty" "github.com/hashicorp/terraform-plugin-sdk/v2/diag" "github.com/hashicorp/terraform-plugin-sdk/v2/helper/schema" "github.com/hashicorp/terraform-plugin-sdk/v2/helper/validation" "github.com/mitchellh/mapstructure" + "golang.org/x/xerrors" ) type Option struct { @@ -26,8 +28,8 @@ type Option struct { } type Validation struct { - Min int - Max int + Min *int + Max *int Monotonic string Regex string @@ -62,8 +64,18 @@ func parameterDataSource() *schema.Resource { ReadContext: func(ctx context.Context, rd *schema.ResourceData, i interface{}) diag.Diagnostics { rd.SetId(uuid.NewString()) + fixedValidation, err := fixValidationResourceData(rd.GetRawConfig(), rd.Get("validation")) + if err != nil { + return diag.FromErr(err) + } + + err = rd.Set("validation", fixedValidation) + if err != nil { + return diag.FromErr(err) + } + var parameter Parameter - err := mapstructure.Decode(struct { + err = mapstructure.Decode(struct { Value interface{} Name interface{} DisplayName interface{} @@ -98,7 +110,7 @@ func parameterDataSource() *schema.Resource { }(), Icon: rd.Get("icon"), Option: rd.Get("option"), - Validation: rd.Get("validation"), + Validation: fixedValidation, Optional: func() bool { // This hack allows for checking if the "default" field is present in the .tf file. // If "default" is missing or is "null", then it means that this field is required, @@ -272,17 +284,14 @@ func parameterDataSource() *schema.Resource { Elem: &schema.Resource{ Schema: map[string]*schema.Schema{ "min": { - Type: schema.TypeInt, - Optional: true, - Default: 0, - Description: "The minimum of a number parameter.", - RequiredWith: []string{"validation.0.max"}, + Type: schema.TypeInt, + Optional: true, + Description: "The minimum of a number parameter.", }, "max": { - Type: schema.TypeInt, - Optional: true, - Description: "The maximum of a number parameter.", - RequiredWith: []string{"validation.0.min"}, + Type: schema.TypeInt, + Optional: true, + Description: "The maximum of a number parameter.", }, "monotonic": { Type: schema.TypeString, @@ -325,6 +334,45 @@ func parameterDataSource() *schema.Resource { } } +func fixValidationResourceData(rawConfig cty.Value, validation interface{}) (interface{}, error) { + // Read validation from raw config + rawValidation, ok := rawConfig.AsValueMap()["validation"] + if !ok { + return validation, nil // no validation rules, nothing to fix + } + + rawValidationArr := rawValidation.AsValueSlice() + if len(rawValidationArr) == 0 { + return validation, nil // no validation rules, nothing to fix + } + + rawValidationRule := rawValidationArr[0].AsValueMap() + + // Load validation from resource data + vArr, ok := validation.([]interface{}) + if !ok { + return nil, xerrors.New("validation should be an array") + } + + if len(vArr) == 0 { + return validation, nil // no validation rules, nothing to fix + } + + validationRule, ok := vArr[0].(map[string]interface{}) + if !ok { + return nil, xerrors.New("validation rule should be a map") + } + + // Fix the resource data + if rawValidationRule["min"].IsNull() { + validationRule["min"] = nil + } + if rawValidationRule["max"].IsNull() { + validationRule["max"] = nil + } + return vArr, nil +} + func valueIsType(typ, value string) diag.Diagnostics { switch typ { case "number": @@ -353,10 +401,10 @@ func valueIsType(typ, value string) diag.Diagnostics { func (v *Validation) Valid(typ, value string) error { if typ != "number" { - if v.Min != 0 { + if v.Min != nil { return fmt.Errorf("a min cannot be specified for a %s type", typ) } - if v.Max != 0 { + if v.Max != nil { return fmt.Errorf("a max cannot be specified for a %s type", typ) } } @@ -389,10 +437,10 @@ func (v *Validation) Valid(typ, value string) error { if err != nil { return fmt.Errorf("value %q is not a number", value) } - if num < v.Min { + if v.Min != nil && num < *v.Min { return fmt.Errorf("value %d is less than the minimum %d", num, v.Min) } - if num > v.Max { + if v.Max != nil && num > *v.Max { return fmt.Errorf("value %d is more than the maximum %d", num, v.Max) } if v.Monotonic != "" && v.Monotonic != ValidationMonotonicIncreasing && v.Monotonic != ValidationMonotonicDecreasing { diff --git a/provider/parameter_test.go b/provider/parameter_test.go index c349190b..4b2fe9c9 100644 --- a/provider/parameter_test.go +++ b/provider/parameter_test.go @@ -109,6 +109,30 @@ func TestParameter(t *testing.T) { } } `, + }, { + Name: "NumberValidation_Min", + Config: ` + data "coder_parameter" "region" { + name = "Region" + type = "number" + default = 2 + validation { + min = 1 + } + } + `, + }, { + Name: "NumberValidation_Max", + Config: ` + data "coder_parameter" "region" { + name = "Region" + type = "number" + default = 2 + validation { + max = 9 + } + } + `, }, { Name: "DefaultNotNumber", Config: ` @@ -443,18 +467,18 @@ func TestValueValidatesType(t *testing.T) { Regex, RegexError string Min, - Max int + Max *int Monotonic string Error *regexp.Regexp }{{ Name: "StringWithMin", Type: "string", - Min: 1, + Min: ptrNumber(1), Error: regexp.MustCompile("cannot be specified"), }, { Name: "StringWithMax", Type: "string", - Max: 1, + Max: ptrNumber(1), Error: regexp.MustCompile("cannot be specified"), }, { Name: "NonStringWithRegex", @@ -474,13 +498,13 @@ func TestValueValidatesType(t *testing.T) { Name: "NumberBelowMin", Type: "number", Value: "0", - Min: 1, + Min: ptrNumber(1), Error: regexp.MustCompile("is less than the minimum"), }, { Name: "NumberAboveMax", Type: "number", - Value: "1", - Max: 0, + Value: "2", + Max: ptrNumber(1), Error: regexp.MustCompile("is more than the maximum"), }, { Name: "InvalidBool", @@ -498,23 +522,23 @@ func TestValueValidatesType(t *testing.T) { Name: "InvalidMonotonicity", Type: "number", Value: "1", - Min: 0, - Max: 2, + Min: ptrNumber(0), + Max: ptrNumber(2), Monotonic: "foobar", Error: regexp.MustCompile(`number monotonicity can be either "increasing" or "decreasing"`), }, { Name: "IncreasingMonotonicity", Type: "number", Value: "1", - Min: 0, - Max: 2, + Min: ptrNumber(0), + Max: ptrNumber(2), Monotonic: "increasing", }, { Name: "DecreasingMonotonicity", Type: "number", Value: "1", - Min: 0, - Max: 2, + Min: ptrNumber(0), + Max: ptrNumber(2), Monotonic: "decreasing", }, { Name: "ValidListOfStrings", @@ -550,3 +574,7 @@ func TestValueValidatesType(t *testing.T) { }) } } + +func ptrNumber(i int) *int { + return &i +}