diff --git a/commands/commands.go b/commands/commands.go index ac858550b..e5b07faff 100644 --- a/commands/commands.go +++ b/commands/commands.go @@ -85,6 +85,24 @@ func Run(v string) int { sort.Sort(cli.CommandsByName(commands)) app.Commands = commands app.EnableBashCompletion = true + app.BashComplete = func(cCtx *cli.Context) { + // ripped from cli.DefaultCompleteWithFlags + var lastArg string + + if len(os.Args) > 2 { + lastArg = os.Args[len(os.Args)-2] + } + + if lastArg != "" { + if strings.HasPrefix(lastArg, "-") { + if !islastFlagComplete(lastArg, app.Flags) { + dnscontrolPrintFlagSuggestions(lastArg, app.Flags, cCtx.App.Writer) + return + } + } + } + dnscontrolPrintCommandSuggestions(app.Commands, cCtx.App.Writer) + } if err := app.Run(os.Args); err != nil { return 1 } diff --git a/commands/completion.go b/commands/completion.go index 47d00cacf..448004555 100644 --- a/commands/completion.go +++ b/commands/completion.go @@ -4,11 +4,14 @@ import ( "embed" "errors" "fmt" - "github.com/urfave/cli/v2" + "io" "os" "path" "strings" "text/template" + "unicode/utf8" + + "github.com/urfave/cli/v2" ) //go:embed completion-scripts/completion.*.gotmpl @@ -89,3 +92,79 @@ func getCompletionSupportedShells() (shells []string, shellCompletionScripts map } return shells, shellCompletionScripts, nil } + +func dnscontrolPrintCommandSuggestions(commands []*cli.Command, writer io.Writer) { + for _, command := range commands { + if command.Hidden { + continue + } + if strings.HasSuffix(os.Getenv("SHELL"), "zsh") { + for _, name := range command.Names() { + _, _ = fmt.Fprintf(writer, "%s:%s\n", name, command.Usage) + } + } else { + for _, name := range command.Names() { + _, _ = fmt.Fprintf(writer, "%s\n", name) + } + } + } +} + +func dnscontrolCliArgContains(flagName string) bool { + for _, name := range strings.Split(flagName, ",") { + name = strings.TrimSpace(name) + count := utf8.RuneCountInString(name) + if count > 2 { + count = 2 + } + flag := fmt.Sprintf("%s%s", strings.Repeat("-", count), name) + for _, a := range os.Args { + if a == flag { + return true + } + } + } + return false +} + +func dnscontrolPrintFlagSuggestions(lastArg string, flags []cli.Flag, writer io.Writer) { + cur := strings.TrimPrefix(lastArg, "-") + cur = strings.TrimPrefix(cur, "-") + for _, flag := range flags { + if bflag, ok := flag.(*cli.BoolFlag); ok && bflag.Hidden { + continue + } + for _, name := range flag.Names() { + name = strings.TrimSpace(name) + // this will get total count utf8 letters in flag name + count := utf8.RuneCountInString(name) + if count > 2 { + count = 2 // reuse this count to generate single - or -- in flag completion + } + // if flag name has more than one utf8 letter and last argument in cli has -- prefix then + // skip flag completion for short flags example -v or -x + if strings.HasPrefix(lastArg, "--") && count == 1 { + continue + } + // match if last argument matches this flag and it is not repeated + if strings.HasPrefix(name, cur) && cur != name && !dnscontrolCliArgContains(name) { + flagCompletion := fmt.Sprintf("%s%s", strings.Repeat("-", count), name) + _, _ = fmt.Fprintln(writer, flagCompletion) + } + } + } +} + +func islastFlagComplete(lastArg string, flags []cli.Flag) bool { + cur := strings.TrimPrefix(lastArg, "-") + cur = strings.TrimPrefix(cur, "-") + for _, flag := range flags { + for _, name := range flag.Names() { + name = strings.TrimSpace(name) + if strings.HasPrefix(name, cur) && cur != name && !dnscontrolCliArgContains(name) { + return false + } + } + } + return true +}