diff --git a/pkg/interactive/builder.go b/pkg/interactive/builder.go new file mode 100644 index 00000000..f74689eb --- /dev/null +++ b/pkg/interactive/builder.go @@ -0,0 +1,446 @@ +package interactive + +import ( + "encoding/json" + "errors" + "fmt" + "reflect" + "slices" + "strconv" +) + +// DefaultMaxDepth bounds recursion to defeat self-referential unions. When the +// guard fires the builder degrades to a raw-JSON prompt. +const DefaultMaxDepth = 12 + +// DefaultParamBagThreshold is the optional-field count above which a struct is +// treated as a parameter bag (the user multi-selects which fields to set). +const DefaultParamBagThreshold = 15 + +// Builder prompts for every field of a struct via reflection. +type Builder struct { + Prompter Prompter + MaxDepth int // 0 uses DefaultMaxDepth + ParamBagThreshold int // 0 uses DefaultParamBagThreshold +} + +func (b *Builder) maxDepth() int { + if b.MaxDepth <= 0 { + return DefaultMaxDepth + } + return b.MaxDepth +} + +func (b *Builder) paramBagThreshold() int { + if b.ParamBagThreshold <= 0 { + return DefaultParamBagThreshold + } + return b.ParamBagThreshold +} + +// Build reflectively prompts for every field of the struct pointed to by v. +// v must be a non-nil pointer to a struct. Fields already non-zero on v are +// preserved (lets the caller pre-populate identifiers). +func (b *Builder) Build(v any) error { + if b.Prompter == nil { + return errors.New("interactive: Builder.Prompter must be set") + } + rv := reflect.ValueOf(v) + if rv.Kind() != reflect.Pointer || rv.IsNil() || rv.Elem().Kind() != reflect.Struct { + return errors.New("interactive: Build needs a non-nil pointer to a struct") + } + return b.buildValue(rv.Elem().Type().Name(), rv.Elem(), 0, nil) +} + +func (b *Builder) buildValue(label string, rv reflect.Value, depth int, stack []reflect.Type) error { + if depth > b.maxDepth() || slices.Contains(stack, rv.Type()) { + return b.promptRawJSON(label, rv) + } + + switch rv.Kind() { + case reflect.Pointer: + return b.buildPointer(label, rv, depth, stack) + case reflect.Struct: + return b.buildStruct(label, rv, depth, stack) + case reflect.Slice: + return b.buildSlice(label, rv, depth, stack) + case reflect.Map: + return b.buildMap(label, rv, depth, stack) + case reflect.String, reflect.Bool, + reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, + reflect.Float32, reflect.Float64: + _, err := b.assignScalar(label, rv, false) + return err + case reflect.Interface: + return b.promptRawJSON(label, rv) + default: + return fmt.Errorf("interactive: unsupported kind %s for %s", rv.Kind(), label) + } +} + +func (b *Builder) inputInt(label string, bits int) (int64, error) { + s, err := b.Prompter.Input(label+" (integer)", countValidator()) + if err != nil { + return 0, err + } + if s == "" { + return 0, nil + } + n, err := strconv.ParseInt(s, 10, bits) + if err != nil { + return 0, fmt.Errorf("invalid integer for %s: %w", label, err) + } + return n, nil +} + +// isEnumType reports whether t is a named string type exposing a value-receiver +// IsValid() bool method, the convention every Algolia SDK enum is generated +// with. Such a field is validated against the SDK's own allowed-value check +// rather than a hand-maintained list. +func isEnumType(t reflect.Type) bool { + if t.Kind() != reflect.String || t.Name() == "string" { + return false + } + m, ok := t.MethodByName("IsValid") + if !ok { + return false + } + ft := m.Func.Type() // the receiver counts as the first in-param + return ft.NumIn() == 1 && ft.NumOut() == 1 && ft.Out(0).Kind() == reflect.Bool +} + +// enumValidator validates that s is an allowed value of enum type t by calling +// the type's IsValid() method reflectively. Empty is allowed only when optional. +func enumValidator(t reflect.Type, label string, optional bool) func(string) error { + return func(s string) error { + if s == "" { + if optional { + return nil + } + return fmt.Errorf("%s is required", label) + } + cand := reflect.New(t).Elem() + cand.SetString(s) + if !cand.MethodByName("IsValid").Call(nil)[0].Bool() { + return fmt.Errorf("%q is not a valid %s", s, t.Name()) + } + return nil + } +} + +// assignScalar prompts for a scalar value and writes it into v, which must be +// settable. optional reports whether an empty answer means "skip" (leaving v +// untouched) rather than writing the zero value, and selects the prompt hints. +// It returns whether v was set. Bool keeps a deliberate asymmetry: a required +// bool is a yes/no Confirm, while an optional *bool is free-text true/false with +// an empty answer meaning skip. +func (b *Builder) assignScalar(label string, v reflect.Value, optional bool) (bool, error) { + switch v.Kind() { + case reflect.String: + // SDK enums (named string types with an IsValid() bool method) are + // validated against the SDK's own allowed-value check; plain strings use + // the required/optional rule. + var validate func(string) error + switch { + case isEnumType(v.Type()): + validate = enumValidator(v.Type(), label, optional) + case !optional: + validate = requiredString(label) + } + s, err := b.Prompter.Input(label, validate) + if err != nil { + return false, err + } + if optional && s == "" { + return false, nil + } + v.SetString(s) + return true, nil + case reflect.Bool: + if optional { + s, err := b.Prompter.Input(label+" (true/false, empty to skip)", boolValidator()) + if err != nil || s == "" { + return false, err + } + val, perr := strconv.ParseBool(s) + if perr != nil { + return false, fmt.Errorf("invalid boolean for %s: %w", label, perr) + } + v.SetBool(val) + return true, nil + } + val, err := b.Prompter.Confirm(label) + if err != nil { + return false, err + } + v.SetBool(val) + return true, nil + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + hint := " (integer)" + if optional { + hint = " (integer, empty to skip)" + } + s, err := b.Prompter.Input(label+hint, intValidator(label, v.Type().Bits(), optional)) + if err != nil { + return false, err + } + if s == "" { + // Only reachable when optional: intValidator rejects empty for + // required fields, so the prompter re-prompts or errors before here. + return false, nil + } + n, perr := strconv.ParseInt(s, 10, v.Type().Bits()) + if perr != nil { + return false, fmt.Errorf("invalid integer for %s: %w", label, perr) + } + v.SetInt(n) + return true, nil + case reflect.Float32, reflect.Float64: + hint := " (number)" + if optional { + hint = " (number, empty to skip)" + } + s, err := b.Prompter.Input(label+hint, floatValidator(label, v.Type().Bits(), optional)) + if err != nil { + return false, err + } + if s == "" { + // Only reachable when optional (see the integer case above). + return false, nil + } + // Parse at the field's precision so a float32 overflow is caught here + // rather than later becoming +Inf, which encoding/json cannot marshal. + f, perr := strconv.ParseFloat(s, v.Type().Bits()) + if perr != nil { + return false, fmt.Errorf("invalid number for %s: %w", label, perr) + } + v.SetFloat(f) + return true, nil + default: + return false, fmt.Errorf("interactive: unsupported scalar kind %s for %s", v.Kind(), label) + } +} + +func (b *Builder) buildPointer(label string, rv reflect.Value, depth int, stack []reflect.Type) error { + elem := rv.Type().Elem() + switch elem.Kind() { + case reflect.String, reflect.Bool, + reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, + reflect.Float32, reflect.Float64: + // Optional scalar: build into a fresh element and only assign the pointer + // if the user provided a value. + ptr := reflect.New(elem) + set, err := b.assignScalar(label, ptr.Elem(), true) + if err != nil || !set { + return err + } + rv.Set(ptr) + return nil + } + + // Pointer to struct, slice, map, or other composite. + set, err := b.Prompter.Confirm("Set " + label + "?") + if err != nil || !set { + return err + } + ptr := reflect.New(elem) + if err := b.buildValue(label, ptr.Elem(), depth+1, stack); err != nil { + return err + } + rv.Set(ptr) + return nil +} + +func (b *Builder) promptRawJSON(label string, rv reflect.Value) error { + raw, err := b.Prompter.Input(label+" (raw JSON, empty to skip)", jsonValidator()) + if err != nil || raw == "" { + return err + } + if !rv.CanAddr() { + return fmt.Errorf("interactive: cannot unmarshal raw JSON into unaddressable %s", label) + } + if err := json.Unmarshal([]byte(raw), rv.Addr().Interface()); err != nil { + return fmt.Errorf("invalid JSON for %s: %w", label, err) + } + return nil +} + +func (b *Builder) buildStruct(label string, rv reflect.Value, depth int, stack []reflect.Type) error { + t := rv.Type() + if shouldSkipType(t) { + return nil + } + if isUnionType(t) { + return b.buildUnion(label, rv, depth, stack) + } + if isParamBag(t, b.paramBagThreshold()) { + return b.buildParamBag(label, rv, depth, stack) + } + stack = append(stack, t) + for i := 0; i < t.NumField(); i++ { + f := t.Field(i) + if !f.IsExported() { + continue + } + jsonTag := f.Tag.Get("json") + if jsonTag == "" || jsonTag == "-" { + continue + } + // Preserve fields the caller pre-populated. + if !rv.Field(i).IsZero() { + continue + } + fieldLabel := label + "." + jsonFieldName(jsonTag) + if err := b.buildValue(fieldLabel, rv.Field(i), depth+1, stack); err != nil { + return err + } + } + return nil +} + +func (b *Builder) buildUnion(label string, rv reflect.Value, depth int, stack []reflect.Type) error { + t := rv.Type() + type variant struct { + fieldIdx int + name string + } + var variants []variant + var labels []string + for i := 0; i < t.NumField(); i++ { + f := t.Field(i) + if !f.IsExported() || f.Type.Kind() != reflect.Pointer { + continue + } + name := f.Type.Elem().Name() + if name == "" { + name = f.Name + } + variants = append(variants, variant{i, name}) + labels = append(labels, name) + } + if len(variants) == 0 { + return nil + } + pick, err := b.Prompter.Select(label+" (variant)", labels) + if err != nil { + return err + } + for _, v := range variants { + if v.name != pick { + continue + } + field := rv.Field(v.fieldIdx) + ptr := reflect.New(field.Type().Elem()) + stack = append(stack, t) + if err := b.buildValue(label+"."+v.name, ptr.Elem(), depth+1, stack); err != nil { + return err + } + field.Set(ptr) + return nil + } + return nil +} + +func (b *Builder) buildParamBag(label string, rv reflect.Value, depth int, stack []reflect.Type) error { + t := rv.Type() + type entry struct { + fieldIdx int + jsonName string + typeName string + } + var entries []entry + for i := 0; i < t.NumField(); i++ { + f := t.Field(i) + if !f.IsExported() { + continue + } + jsonTag := f.Tag.Get("json") + if jsonTag == "" || jsonTag == "-" { + continue + } + ft := f.Type + for ft.Kind() == reflect.Pointer { + ft = ft.Elem() + } + typeName := ft.Name() + if typeName == "" { + typeName = ft.Kind().String() + } + entries = append(entries, entry{i, jsonFieldName(jsonTag), typeName}) + } + if len(entries) == 0 { + return nil + } + options := make([]string, len(entries)) + for i, e := range entries { + options[i] = fmt.Sprintf("%s (%s)", e.jsonName, e.typeName) + } + picked, err := b.Prompter.MultiSelect(label, options) + if err != nil { + return err + } + stack = append(stack, t) + for _, idx := range picked { + if idx < 0 || idx >= len(entries) { + continue // defend against a Prompter returning out-of-range indexes + } + e := entries[idx] + fieldLabel := label + "." + e.jsonName + if err := b.buildValue(fieldLabel, rv.Field(e.fieldIdx), depth+1, stack); err != nil { + return err + } + } + return nil +} + +func (b *Builder) buildSlice(label string, rv reflect.Value, depth int, stack []reflect.Type) error { + count, err := b.inputInt("how many "+label, 32) + if err != nil { + return err + } + if count <= 0 { + return nil + } + t := rv.Type() + out := reflect.MakeSlice(t, int(count), int(count)) + for i := 0; i < int(count); i++ { + itemLabel := fmt.Sprintf("%s[%d]", label, i) + if err := b.buildValue(itemLabel, out.Index(i), depth+1, stack); err != nil { + return err + } + } + rv.Set(out) + return nil +} + +func (b *Builder) buildMap(label string, rv reflect.Value, depth int, stack []reflect.Type) error { + t := rv.Type() + if t.Key().Kind() != reflect.String { + return b.promptRawJSON(label, rv) + } + count, err := b.inputInt("how many "+label+" entries", 32) + if err != nil { + return err + } + if count <= 0 { + // Leave the field at its zero value (nil map) so it is omitted, matching + // buildSlice. A confirmed-but-empty map would otherwise serialize as {}. + return nil + } + valType := t.Elem() + out := reflect.MakeMapWithSize(t, int(count)) + for i := 0; i < int(count); i++ { + keyLabel := fmt.Sprintf("%s key[%d]", label, i) + key, err := b.Prompter.Input(keyLabel, requiredString(keyLabel)) + if err != nil { + return err + } + valPtr := reflect.New(valType) + if err := b.buildValue(fmt.Sprintf("%s[%q]", label, key), valPtr.Elem(), depth+1, stack); err != nil { + return err + } + out.SetMapIndex(reflect.ValueOf(key), valPtr.Elem()) + } + rv.Set(out) + return nil +} diff --git a/pkg/interactive/builder_test.go b/pkg/interactive/builder_test.go new file mode 100644 index 00000000..a4a88b17 --- /dev/null +++ b/pkg/interactive/builder_test.go @@ -0,0 +1,318 @@ +package interactive + +import ( + "errors" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +type scalars struct { + Name string `json:"name"` + Count int32 `json:"count"` + Ratio float64 `json:"ratio"` + Enabled bool `json:"enabled"` +} + +func TestBuild_Scalars(t *testing.T) { + b := &Builder{Prompter: &ScriptedPrompter{ + Inputs: map[string]string{"name": "widget", "count": "7", "ratio": "1.5"}, + Confirms: map[string]bool{"enabled": true}, + }} + + var v scalars + require.NoError(t, b.Build(&v)) + + assert.Equal(t, "widget", v.Name) + assert.Equal(t, int32(7), v.Count) + assert.Equal(t, 1.5, v.Ratio) + assert.True(t, v.Enabled) +} + +type optionals struct { + Note *string `json:"note,omitempty"` + Max *int32 `json:"max,omitempty"` +} + +func TestBuild_OptionalSkipped(t *testing.T) { + b := &Builder{Prompter: &ScriptedPrompter{}} + + var v optionals + require.NoError(t, b.Build(&v)) + + assert.Nil(t, v.Note) + assert.Nil(t, v.Max) +} + +func TestBuild_OptionalSet(t *testing.T) { + b := &Builder{Prompter: &ScriptedPrompter{ + Inputs: map[string]string{"note": "hi", "max": "9"}, + }} + + var v optionals + require.NoError(t, b.Build(&v)) + + require.NotNil(t, v.Note) + assert.Equal(t, "hi", *v.Note) + require.NotNil(t, v.Max) + assert.Equal(t, int32(9), *v.Max) +} + +type optionalScalars struct { + Flag *bool `json:"flag,omitempty"` + Score *float64 `json:"score,omitempty"` +} + +func TestBuild_OptionalBoolFloatSet(t *testing.T) { + b := &Builder{Prompter: &ScriptedPrompter{ + Inputs: map[string]string{"flag": "true", "score": "2.5"}, + }} + + var v optionalScalars + require.NoError(t, b.Build(&v)) + + require.NotNil(t, v.Flag) + assert.True(t, *v.Flag) + require.NotNil(t, v.Score) + assert.Equal(t, 2.5, *v.Score) +} + +func TestBuild_OptionalBoolFloatSkipped(t *testing.T) { + b := &Builder{Prompter: &ScriptedPrompter{}} + + var v optionalScalars + require.NoError(t, b.Build(&v)) + + assert.Nil(t, v.Flag) + assert.Nil(t, v.Score) +} + +func TestBuild_RequiredEmptyRejected(t *testing.T) { + // A required string with no scripted answer resolves to "" and is rejected + // by the validator (on a real terminal survey would re-prompt). + b := &Builder{Prompter: &ScriptedPrompter{}} + + var v scalars + err := b.Build(&v) + require.Error(t, err) + assert.Contains(t, err.Error(), "is required") +} + +func TestBuild_TypeMismatchRejected(t *testing.T) { + // A non-numeric answer for an integer field is rejected by the validator. + b := &Builder{Prompter: &ScriptedPrompter{ + Inputs: map[string]string{"name": "x", "count": "abc", "ratio": "1"}, + Confirms: map[string]bool{"enabled": true}, + }} + + var v scalars + err := b.Build(&v) + require.Error(t, err) + assert.Contains(t, err.Error(), "whole number") +} + +// shade mimics an SDK enum: a named string type with an IsValid() bool method. +type shade string + +func (s shade) IsValid() bool { return s == "red" || s == "green" || s == "blue" } + +type enumHolder struct { + Shade shade `json:"shade"` +} + +func TestBuild_Enum(t *testing.T) { + // Enums are validated via their IsValid() method (no registry); a valid + // free-text answer is accepted. + b := &Builder{Prompter: &ScriptedPrompter{Inputs: map[string]string{"shade": "green"}}} + + var v enumHolder + require.NoError(t, b.Build(&v)) + assert.Equal(t, shade("green"), v.Shade) +} + +func TestBuild_EnumRejectsInvalid(t *testing.T) { + // An out-of-set value fails IsValid; the validator surfaces the error. + b := &Builder{Prompter: &ScriptedPrompter{Inputs: map[string]string{"shade": "purple"}}} + + var v enumHolder + err := b.Build(&v) + require.Error(t, err) + assert.Contains(t, err.Error(), "not a valid shade") +} + +type address struct { + City string `json:"city"` +} + +type person struct { + ID string `json:"id"` + Address *address `json:"address,omitempty"` + Tags []string `json:"tags,omitempty"` +} + +func TestBuild_PrePopulatedPreserved(t *testing.T) { + // ID is pre-set; builder must not re-prompt it. Only address (confirm:no) + // and tags (count:0) are walked. + b := &Builder{Prompter: &ScriptedPrompter{ + Inputs: map[string]string{"tags": "0"}, + }} + + v := person{ID: "keep-me"} + require.NoError(t, b.Build(&v)) + + assert.Equal(t, "keep-me", v.ID) + assert.Nil(t, v.Address) + assert.Empty(t, v.Tags) +} + +func TestBuild_NestedPointerStruct(t *testing.T) { + b := &Builder{Prompter: &ScriptedPrompter{ + Inputs: map[string]string{"id": "id1", "city": "Berlin", "tags": "0"}, + Confirms: map[string]bool{"address": true}, + }} + + var v person + require.NoError(t, b.Build(&v)) + + assert.Equal(t, "id1", v.ID) + require.NotNil(t, v.Address) + assert.Equal(t, "Berlin", v.Address.City) +} + +func TestBuild_Slice(t *testing.T) { + b := &Builder{Prompter: &ScriptedPrompter{ + Inputs: map[string]string{"id": "id1", "tags": "2", "tags[0]": "a", "tags[1]": "b"}, + }} + + var v person + // address pointer: confirm defaults to false in the fake, so it stays nil. + require.NoError(t, b.Build(&v)) + assert.Equal(t, []string{"a", "b"}, v.Tags) +} + +type union struct { + AsString *string `json:"-"` + AsNumber *int32 `json:"-"` +} + +func TestBuild_Union(t *testing.T) { + b := &Builder{Prompter: &ScriptedPrompter{ + Selects: map[string]string{"variant": "string"}, + Inputs: map[string]string{"union.string": "via-string"}, + }} + + var v union + require.NoError(t, b.Build(&v)) + require.NotNil(t, v.AsString) + assert.Equal(t, "via-string", *v.AsString) + assert.Nil(t, v.AsNumber) +} + +type stringMapHolder struct { + Labels map[string]string `json:"labels,omitempty"` +} + +func TestBuild_StringMap(t *testing.T) { + // Labels is a non-pointer map reached directly (no "Set?" confirm): the + // engine prompts a count then that many key/value pairs. + b := &Builder{Prompter: &ScriptedPrompter{ + Inputs: map[string]string{ + "entries": "2", + "key[0]": "k1", + "key[1]": "k2", + `["k1"]`: "v1", + `["k2"]`: "v2", + }, + }} + + var v stringMapHolder + require.NoError(t, b.Build(&v)) + assert.Equal(t, map[string]string{"k1": "v1", "k2": "v2"}, v.Labels) +} + +type paramBag struct { + Alpha *string `json:"alpha,omitempty"` + Beta *string `json:"beta,omitempty"` + Gamma *string `json:"gamma,omitempty"` +} + +func TestBuild_ParamBag(t *testing.T) { + // threshold 2 -> 3 optional fields qualifies. Select indexes 0 and 2. + b := &Builder{ + Prompter: &ScriptedPrompter{ + MultiSelects: map[string][]int{"paramBag": {0, 2}}, + Inputs: map[string]string{"alpha": "a-val", "gamma": "g-val"}, + }, + ParamBagThreshold: 2, + } + + var v paramBag + require.NoError(t, b.Build(&v)) + require.NotNil(t, v.Alpha) + assert.Equal(t, "a-val", *v.Alpha) + assert.Nil(t, v.Beta) + require.NotNil(t, v.Gamma) + assert.Equal(t, "g-val", *v.Gamma) +} + +type errPrompter struct{ err error } + +func (e errPrompter) Input(string, func(string) error) (string, error) { return "", e.err } +func (e errPrompter) Confirm(string) (bool, error) { return false, e.err } +func (e errPrompter) Select(string, []string) (string, error) { return "", e.err } +func (e errPrompter) MultiSelect(string, []string) ([]int, error) { return nil, e.err } + +func TestBuild_PrompterErrorPropagates(t *testing.T) { + sentinel := errors.New("prompt failed") + b := &Builder{Prompter: errPrompter{err: sentinel}} + + var v scalars + err := b.Build(&v) + require.Error(t, err) + assert.ErrorIs(t, err, sentinel) +} + +type recursive struct { + Name string `json:"name"` + Self *recursive `json:"self,omitempty"` +} + +func TestBuild_CycleGuardFallsBackToRawJSON(t *testing.T) { + // The Self pointer recurses into the same type. Confirming "yes" enters the + // pointer, and at that point the type is already on the recursion stack, so + // the cycle guard fires and degrades to a raw-JSON prompt instead of looping + // forever. The empty raw-JSON input leaves the nested value zero. The test + // completing at all proves the guard stopped the recursion; we additionally + // check the recursion is bounded one level deep (Self.Self stays nil). + b := &Builder{Prompter: &ScriptedPrompter{ + Inputs: map[string]string{"name": "top"}, + Confirms: map[string]bool{"self": true}, + }} + + var v recursive + require.NoError(t, b.Build(&v)) + assert.Equal(t, "top", v.Name) + // Confirm:yes allocates Self; the cycle guard then skips its contents, so + // Self is a non-nil empty leaf and recursion did not continue. + require.NotNil(t, v.Self) + assert.Equal(t, "", v.Self.Name) + assert.Nil(t, v.Self.Self) +} + +func TestBuild_MaxDepthFallsBackToRawJSON(t *testing.T) { + // With MaxDepth 1, descending into the nested address struct (depth 2) + // exceeds the limit and degrades to a raw-JSON prompt. Empty input leaves + // the address contents zero. The build still completes without error. + b := &Builder{ + Prompter: &ScriptedPrompter{ + Inputs: map[string]string{"id": "id1"}, + Confirms: map[string]bool{"address": true}, + }, + MaxDepth: 1, + } + + var v person + require.NoError(t, b.Build(&v)) + assert.Equal(t, "id1", v.ID) +} diff --git a/pkg/interactive/classify.go b/pkg/interactive/classify.go new file mode 100644 index 00000000..5c1d71dd --- /dev/null +++ b/pkg/interactive/classify.go @@ -0,0 +1,80 @@ +package interactive + +import ( + "reflect" + "strings" +) + +// isUnionType reports whether t looks like an OpenAPI oneOf wrapper: a struct +// whose every exported field is a pointer with no JSON name (either untagged, +// as the Algolia SDK emits, or tagged json:"-"). +func isUnionType(t reflect.Type) bool { + if t.Kind() != reflect.Struct || t.NumField() == 0 { + return false + } + exported := 0 + for i := 0; i < t.NumField(); i++ { + f := t.Field(i) + if !f.IsExported() { + continue + } + exported++ + jsonTag := f.Tag.Get("json") + if f.Type.Kind() != reflect.Pointer || (jsonTag != "" && jsonTag != "-") { + return false + } + } + return exported > 0 +} + +// isParamBag reports whether t is a large optional-only parameter object: more +// than threshold optional exported fields and zero required fields. +func isParamBag(t reflect.Type, threshold int) bool { + if t.Kind() != reflect.Struct { + return false + } + exported := 0 + for i := 0; i < t.NumField(); i++ { + f := t.Field(i) + if !f.IsExported() { + continue + } + jsonTag := f.Tag.Get("json") + if jsonTag == "" || jsonTag == "-" { + continue + } + exported++ + if isRequired(f) { + return false + } + } + return exported > threshold +} + +// isRequired reports whether a struct field is required: a non-pointer with a +// json tag that does not contain "omitempty". +func isRequired(f reflect.StructField) bool { + if f.Type.Kind() == reflect.Pointer { + return false + } + tag := f.Tag.Get("json") + if tag == "" || tag == "-" { + return false + } + return !strings.Contains(tag, "omitempty") +} + +// shouldSkipType filters out the SDK's internal /utils helper types. +func shouldSkipType(t reflect.Type) bool { + return strings.Contains(t.PkgPath(), "/utils") +} + +// jsonFieldName returns the field name from a json struct tag, dropping options +// like ",omitempty". +func jsonFieldName(tag string) string { + parts := strings.Split(tag, ",") + if parts[0] != "" { + return parts[0] + } + return tag +} diff --git a/pkg/interactive/classify_test.go b/pkg/interactive/classify_test.go new file mode 100644 index 00000000..c62dd6db --- /dev/null +++ b/pkg/interactive/classify_test.go @@ -0,0 +1,49 @@ +package interactive + +import ( + "reflect" + "testing" + + "github.com/stretchr/testify/assert" +) + +type unionSample struct { + AsString *string + AsInt *int32 +} + +type plainSample struct { + Name string `json:"name"` + Age *int32 `json:"age,omitempty"` +} + +type paramBagSample struct { + A *string `json:"a,omitempty"` + B *string `json:"b,omitempty"` + C *string `json:"c,omitempty"` +} + +func TestIsUnionType(t *testing.T) { + assert.True(t, isUnionType(reflect.TypeOf(unionSample{}))) + assert.False(t, isUnionType(reflect.TypeOf(plainSample{}))) +} + +func TestIsParamBag(t *testing.T) { + // threshold of 2 makes the 3-field struct a param bag; the plain struct has + // a required field so it never qualifies. + assert.True(t, isParamBag(reflect.TypeOf(paramBagSample{}), 2)) + assert.False(t, isParamBag(reflect.TypeOf(plainSample{}), 2)) +} + +func TestJSONFieldName(t *testing.T) { + assert.Equal(t, "name", jsonFieldName("name,omitempty")) + assert.Equal(t, "age", jsonFieldName("age")) +} + +func TestIsRequired(t *testing.T) { + tp := reflect.TypeOf(plainSample{}) + nameField, _ := tp.FieldByName("Name") + ageField, _ := tp.FieldByName("Age") + assert.True(t, isRequired(nameField)) + assert.False(t, isRequired(ageField)) +} diff --git a/pkg/interactive/prompter.go b/pkg/interactive/prompter.go new file mode 100644 index 00000000..e2a3f336 --- /dev/null +++ b/pkg/interactive/prompter.go @@ -0,0 +1,106 @@ +// Package interactive builds request bodies by prompting the user for each +// field of a Go struct via reflection. Input is gathered through the Prompter +// interface so the traversal can be unit-tested without a terminal. +package interactive + +import ( + "os" + + "github.com/AlecAivazis/survey/v2" + + "github.com/algolia/cli/pkg/iostreams" + "github.com/algolia/cli/pkg/prompt" +) + +// Prompter is the input surface used by the reflective Builder. Production code +// uses SurveyPrompter; tests use a scripted fake. +type Prompter interface { + // Input reads a free-text line. validate, when non-nil, is run on the entry; + // the implementation re-prompts until it passes (SurveyPrompter) or surfaces + // its error (ScriptedPrompter). Pass nil for no validation. An empty entry + // means "skip" for optional fields and zero for required scalars. + Input(label string, validate func(string) error) (string, error) + // Confirm asks a yes/no question. + Confirm(label string) (bool, error) + // Select asks the user to pick exactly one option and returns the chosen + // label. + Select(label string, options []string) (string, error) + // MultiSelect asks the user to pick zero or more options and returns the + // chosen 0-based indexes. + MultiSelect(label string, options []string) ([]int, error) +} + +// SurveyPrompter implements Prompter using survey/v2 via the repo's pkg/prompt +// wrappers (which are swappable in tests). +type SurveyPrompter struct { + io *iostreams.IOStreams +} + +// NewSurveyPrompter returns a Prompter that reads and writes the given streams. +func NewSurveyPrompter(io *iostreams.IOStreams) *SurveyPrompter { + return &SurveyPrompter{io: io} +} + +func (s *SurveyPrompter) surveyOpts() []survey.AskOpt { + return []survey.AskOpt{ + survey.WithStdio( + fileReader{s.io.In}, + fileWriter{s.io.Out, os.Stdout.Fd()}, + fileWriter{s.io.ErrOut, os.Stderr.Fd()}, + ), + } +} + +func (s *SurveyPrompter) Input(label string, validate func(string) error) (string, error) { + var out string + opts := s.surveyOpts() + if validate != nil { + // survey re-prompts in place (showing the error, with the line editable) + // until the validator passes. + opts = append(opts, survey.WithValidator(func(ans interface{}) error { + str, ok := ans.(string) + if !ok { + return nil + } + return validate(str) + })) + } + err := prompt.SurveyAskOne(&survey.Input{Message: label}, &out, opts...) + return out, err +} + +func (s *SurveyPrompter) Confirm(label string) (bool, error) { + var out bool + err := prompt.SurveyAskOne(&survey.Confirm{Message: label}, &out, s.surveyOpts()...) + return out, err +} + +func (s *SurveyPrompter) Select(label string, options []string) (string, error) { + var out string + err := prompt.SurveyAskOne(&survey.Select{Message: label, Options: options}, &out, s.surveyOpts()...) + return out, err +} + +func (s *SurveyPrompter) MultiSelect(label string, options []string) ([]int, error) { + var out []int + err := prompt.SurveyAskOne(&survey.MultiSelect{Message: label, Options: options}, &out, s.surveyOpts()...) + return out, err +} + +// fileReader/fileWriter adapt iostreams to survey's terminal file interfaces. +// Fd() returns the real stdio descriptor so survey can toggle raw mode on a +// genuine TTY; in tests the prompt vars are stubbed so Fd is never used. +type fileReader struct { + r interface{ Read([]byte) (int, error) } +} + +func (f fileReader) Read(p []byte) (int, error) { return f.r.Read(p) } +func (f fileReader) Fd() uintptr { return os.Stdin.Fd() } + +type fileWriter struct { + w interface{ Write([]byte) (int, error) } + fd uintptr +} + +func (f fileWriter) Write(p []byte) (int, error) { return f.w.Write(p) } +func (f fileWriter) Fd() uintptr { return f.fd } diff --git a/pkg/interactive/prompter_test.go b/pkg/interactive/prompter_test.go new file mode 100644 index 00000000..6269a098 --- /dev/null +++ b/pkg/interactive/prompter_test.go @@ -0,0 +1,52 @@ +package interactive + +import ( + "testing" + + "github.com/AlecAivazis/survey/v2" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/algolia/cli/pkg/iostreams" + "github.com/algolia/cli/pkg/prompt" +) + +func TestSurveyPrompter_UsesPromptVars(t *testing.T) { + io, _, _, _ := iostreams.Test() + p := NewSurveyPrompter(io) + + origAskOne := prompt.SurveyAskOne + t.Cleanup(func() { + prompt.SurveyAskOne = origAskOne + }) + + prompt.SurveyAskOne = func(sp survey.Prompt, response interface{}, _ ...survey.AskOpt) error { + switch sp.(type) { + case *survey.Input: + *(response.(*string)) = "typed" + case *survey.Confirm: + *(response.(*bool)) = true + case *survey.Select: + *(response.(*string)) = "picked" + case *survey.MultiSelect: + *(response.(*[]int)) = []int{1} + } + return nil + } + + in, err := p.Input("x", nil) + require.NoError(t, err) + assert.Equal(t, "typed", in) + + c, err := p.Confirm("x") + require.NoError(t, err) + assert.True(t, c) + + sel, err := p.Select("x", []string{"a", "b"}) + require.NoError(t, err) + assert.Equal(t, "picked", sel) + + multi, err := p.MultiSelect("x", []string{"a", "b"}) + require.NoError(t, err) + assert.Equal(t, []int{1}, multi) +} diff --git a/pkg/interactive/scripted_prompter.go b/pkg/interactive/scripted_prompter.go new file mode 100644 index 00000000..66fe25f6 --- /dev/null +++ b/pkg/interactive/scripted_prompter.go @@ -0,0 +1,83 @@ +package interactive + +import "strings" + +// ScriptedPrompter is a deterministic Prompter for tests. Answers are keyed by a +// substring of the prompt label rather than by call order, so a test does not +// break when the target struct (for example a generated SDK type) adds or +// reorders fields. A prompt whose label matches no key returns a safe default: +// empty input, a false confirm, the first option, or no multi-selection. Unknown +// or newly added fields are therefore simply skipped. +// +// Matching rule: a key matches a label when the label equals the key or contains +// it as a substring. An exact match wins over any substring match; among +// substring matches the longest (most specific) key wins, so the result is +// deterministic regardless of map iteration order. +type ScriptedPrompter struct { + Inputs map[string]string + Confirms map[string]bool + Selects map[string]string + MultiSelects map[string][]int +} + +var _ Prompter = (*ScriptedPrompter)(nil) + +func (p *ScriptedPrompter) Input(label string, validate func(string) error) (string, error) { + v, _ := lookup(p.Inputs, label) + // Run the validator on the scripted answer so invalid scripted data surfaces + // deterministically (no infinite re-prompt loop). In production survey would + // re-prompt the user instead. + if validate != nil { + if err := validate(v); err != nil { + return "", err + } + } + return v, nil +} + +func (p *ScriptedPrompter) Confirm(label string) (bool, error) { + if v, ok := lookup(p.Confirms, label); ok { + return v, nil + } + return false, nil +} + +func (p *ScriptedPrompter) Select(label string, options []string) (string, error) { + if v, ok := lookup(p.Selects, label); ok { + return v, nil + } + if len(options) > 0 { + return options[0], nil + } + return "", nil +} + +func (p *ScriptedPrompter) MultiSelect(label string, options []string) ([]int, error) { + if v, ok := lookup(p.MultiSelects, label); ok { + return v, nil + } + return nil, nil +} + +// lookup returns the value whose key exactly equals label, or failing that the +// value whose key is the longest substring of label. The bool reports a match. +func lookup[T any](m map[string]T, label string) (T, bool) { + var zero T + if m == nil { + return zero, false + } + if v, ok := m[label]; ok { + return v, true + } + var best T + bestLen := -1 + for k, v := range m { + if k == "" || !strings.Contains(label, k) { + continue + } + if len(k) > bestLen { + best, bestLen = v, len(k) + } + } + return best, bestLen >= 0 +} diff --git a/pkg/interactive/scripted_prompter_test.go b/pkg/interactive/scripted_prompter_test.go new file mode 100644 index 00000000..632e19fc --- /dev/null +++ b/pkg/interactive/scripted_prompter_test.go @@ -0,0 +1,72 @@ +package interactive + +import ( + "errors" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestScriptedPrompter_ValidatorError(t *testing.T) { + // A non-nil validator that rejects the scripted answer surfaces as an error + // rather than looping (the fake does not re-prompt). + p := &ScriptedPrompter{Inputs: map[string]string{"n": "abc"}} + _, err := p.Input("n", func(s string) error { + if s == "abc" { + return errors.New("bad") + } + return nil + }) + require.Error(t, err) +} + +func TestScriptedPrompter_Matched(t *testing.T) { + p := &ScriptedPrompter{ + Inputs: map[string]string{"name": "widget"}, + Confirms: map[string]bool{"address": true}, + Selects: map[string]string{"variant": "string"}, + MultiSelects: map[string][]int{"paramBag": {0, 2}}, + } + + in, err := p.Input("scalars.name", nil) + require.NoError(t, err) + assert.Equal(t, "widget", in) + + c, err := p.Confirm("Set person.address?") + require.NoError(t, err) + assert.True(t, c) + + sel, err := p.Select("union (variant)", []string{"string", "int32"}) + require.NoError(t, err) + assert.Equal(t, "string", sel) + + ms, err := p.MultiSelect("paramBag", []string{"a", "b", "c"}) + require.NoError(t, err) + assert.Equal(t, []int{0, 2}, ms) +} + +func TestScriptedPrompter_Defaults(t *testing.T) { + // An empty script: every unmatched prompt returns a safe default, so unknown + // or newly added fields are simply skipped. + p := &ScriptedPrompter{} + + in, _ := p.Input("anything", nil) + assert.Equal(t, "", in) + c, _ := p.Confirm("anything") + assert.False(t, c) + sel, _ := p.Select("anything", []string{"first", "second"}) + assert.Equal(t, "first", sel) // first option + ms, _ := p.MultiSelect("anything", []string{"a"}) + assert.Nil(t, ms) +} + +func TestScriptedPrompter_LongestKeyWins(t *testing.T) { + // Both keys are substrings of the element label; the more specific one wins. + p := &ScriptedPrompter{Inputs: map[string]string{"tags": "2", "tags[0]": "a"}} + + got, _ := p.Input("person.tags[0]", nil) + assert.Equal(t, "a", got) + got, _ = p.Input("how many person.tags (integer)", nil) + assert.Equal(t, "2", got) +} diff --git a/pkg/interactive/validators.go b/pkg/interactive/validators.go new file mode 100644 index 00000000..67a5d8be --- /dev/null +++ b/pkg/interactive/validators.go @@ -0,0 +1,103 @@ +package interactive + +import ( + "encoding/json" + "fmt" + "strconv" +) + +// The validator builders below produce func(string) error values that are +// passed to Prompter.Input. With SurveyPrompter these become survey validators, +// so an invalid entry re-prompts in place (with the message shown and the line +// editable) instead of aborting the whole build. The functions are pure, so the +// validation rules can be unit-tested directly without a terminal. + +// requiredString rejects an empty entry. +func requiredString(label string) func(string) error { + return func(s string) error { + if s == "" { + return fmt.Errorf("%s is required", label) + } + return nil + } +} + +// intValidator accepts an integer that fits in bits. An empty entry is allowed +// only when optional (it means skip / zero); otherwise it is required. +func intValidator(label string, bits int, optional bool) func(string) error { + return func(s string) error { + if s == "" { + if optional { + return nil + } + return fmt.Errorf("%s is required", label) + } + if _, err := strconv.ParseInt(s, 10, bits); err != nil { + return fmt.Errorf("must be a whole number") + } + return nil + } +} + +// floatValidator accepts a number that fits in bits (32 or 64), so a float32 +// overflow is rejected at input rather than later becoming +Inf. An empty entry +// is allowed only when optional. +func floatValidator(label string, bits int, optional bool) func(string) error { + return func(s string) error { + if s == "" { + if optional { + return nil + } + return fmt.Errorf("%s is required", label) + } + if _, err := strconv.ParseFloat(s, bits); err != nil { + return fmt.Errorf("must be a number") + } + return nil + } +} + +// countValidator accepts a non-negative whole number for "how many ..." prompts. +// An empty entry is allowed (it means none). +func countValidator() func(string) error { + return func(s string) error { + if s == "" { + return nil + } + n, err := strconv.ParseInt(s, 10, 32) + if err != nil { + return fmt.Errorf("must be a whole number") + } + if n < 0 { + return fmt.Errorf("must be zero or greater") + } + return nil + } +} + +// boolValidator accepts a parseable boolean. It is used for optional *bool, so +// an empty entry (skip) is allowed. +func boolValidator() func(string) error { + return func(s string) error { + if s == "" { + return nil + } + if _, err := strconv.ParseBool(s); err != nil { + return fmt.Errorf("must be true or false") + } + return nil + } +} + +// jsonValidator accepts valid JSON. An empty entry (skip) is allowed. +func jsonValidator() func(string) error { + return func(s string) error { + if s == "" { + return nil + } + if !json.Valid([]byte(s)) { + return fmt.Errorf("must be valid JSON") + } + return nil + } +} diff --git a/pkg/interactive/validators_test.go b/pkg/interactive/validators_test.go new file mode 100644 index 00000000..431b8729 --- /dev/null +++ b/pkg/interactive/validators_test.go @@ -0,0 +1,79 @@ +package interactive + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestRequiredString(t *testing.T) { + v := requiredString("name") + require.Error(t, v("")) + require.NoError(t, v("x")) +} + +func TestIntValidator(t *testing.T) { + req := intValidator("count", 32, false) + require.Error(t, req("")) // required: empty rejected + require.Error(t, req("abc")) // not a number + require.Error(t, req("3.5")) // not a whole number + require.NoError(t, req("42")) // ok + + opt := intValidator("max", 32, true) + require.NoError(t, opt("")) // optional: empty allowed + require.Error(t, opt("abc")) // still must parse when present + require.NoError(t, opt("-7")) // ok + + // bit width is enforced. + require.Error(t, intValidator("n", 32, true)("9999999999999")) // overflows int32 +} + +func TestFloatValidator(t *testing.T) { + req := floatValidator("ratio", 64, false) + require.Error(t, req("")) + require.Error(t, req("abc")) + require.NoError(t, req("1.5")) + + opt := floatValidator("ratio", 64, true) + require.NoError(t, opt("")) + require.Error(t, opt("x")) + require.NoError(t, opt("2")) + + // 32-bit overflow is rejected at the field's precision. + require.Error(t, floatValidator("ratio", 32, true)("1e40")) + require.NoError(t, floatValidator("ratio", 64, true)("1e40")) +} + +func TestCountValidator(t *testing.T) { + v := countValidator() + require.NoError(t, v("")) // none + require.NoError(t, v("0")) + require.NoError(t, v("3")) + require.Error(t, v("-1")) // negative rejected + require.Error(t, v("abc")) +} + +func TestBoolValidator(t *testing.T) { + v := boolValidator() + require.NoError(t, v("")) // optional skip + require.NoError(t, v("true")) + require.NoError(t, v("false")) + require.NoError(t, v("1")) + require.Error(t, v("maybe")) +} + +func TestJSONValidator(t *testing.T) { + v := jsonValidator() + require.NoError(t, v("")) // skip + require.NoError(t, v(`{"a":1}`)) + require.NoError(t, v(`[1,2,3]`)) + require.Error(t, v(`{not json}`)) +} + +func TestValidatorMessagesAreUserFacing(t *testing.T) { + // The message survey shows on retry should be terse, not a wrapped Go error. + assert.EqualError(t, intValidator("count", 32, false)("abc"), "must be a whole number") + assert.EqualError(t, boolValidator()("x"), "must be true or false") + assert.EqualError(t, jsonValidator()("{"), "must be valid JSON") +}