Custom command-line flags with flag.Func

One of my favorite things about the Go 1.16 release is a small, but very welcome, addition to the flag package: the flag.Func() function. It makes it much easier to define and use custom command-line flags in your application.

For example, if you want to parse a flag like --urls="http://example.com http://example.org" directly into a []*url.URL slice, you previously had two options. You could either create a custom type that implements the flag.Value interface, or use a third-party package like pflag.

But now the flag.Func() function gives you a simple and lightweight alternative. In this short post we're going to take a look at a few examples of how you can use it in your own code.
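For comparison, here's a rough sketch of the first of those older options: a custom type that implements the flag.Value interface (a String() method and a Set(string) error method) and is registered with Var(). The urlList type name is my own invention for illustration, and the sketch uses flag.NewFlagSet with a fixed argument slice, rather than os.Args, so that it runs on its own.

```go
package main

import (
	"flag"
	"fmt"
	"net/url"
	"strings"
)

// urlList is a hypothetical custom type that satisfies flag.Value by
// implementing String() and Set(string) error.
type urlList []*url.URL

// String returns the URLs as a single space-separated string. The flag
// package may call this on a zero-value receiver, so handle nil safely.
func (l *urlList) String() string {
	if l == nil {
		return ""
	}
	parts := make([]string, 0, len(*l))
	for _, u := range *l {
		parts = append(parts, u.String())
	}
	return strings.Join(parts, " ")
}

// Set is called by the flag package with the raw command-line value.
func (l *urlList) Set(value string) error {
	for _, s := range strings.Fields(value) {
		u, err := url.Parse(s)
		if err != nil {
			return err
		}
		*l = append(*l, u)
	}
	return nil
}

func main() {
	fs := flag.NewFlagSet("example", flag.ExitOnError)

	var urls urlList
	fs.Var(&urls, "urls", "List of URLs to print")

	// Parse a fixed argument list instead of os.Args[1:].
	fs.Parse([]string{"--urls=http://example.com http://example.org"})

	for _, u := range urls {
		fmt.Println(u)
	}
}
```

That's a named type and two methods just to read one flag, which is exactly the boilerplate flag.Func() lets you skip.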

Parsing custom flag types

To demonstrate how this works, let's start with the example I gave above and create an application which accepts a list of URLs, converts each one to a url.URL type, and then prints them out. Running it should look something like this:

$ go run . --urls="http://example.com http://example.org http://example.net"
2023/11/13 17:25:37 http://example.com
2023/11/13 17:25:37 http://example.org
2023/11/13 17:25:37 http://example.net

To make this work, we'll need to do two things:

  • Split the values in the --urls flag on whitespace to get the individual URLs. The strings.Fields() function is a good fit for this task.
  • Convert the individual URL string values to a url.URL type. We can do this using the url.Parse() function.
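To see what those two functions do on their own, before wiring them into a flag, here's a small sketch. The splitURLs() helper is hypothetical, just a name for this illustration. Note that strings.Fields() treats any run of spaces, tabs or newlines as a single separator, and that url.Parse() returns a *url.URL (a pointer), which is why the URLs are collected into a []*url.URL slice.

```go
package main

import (
	"fmt"
	"net/url"
	"strings"
)

// splitURLs is a hypothetical helper: it splits s on whitespace and parses
// each part with url.Parse, stopping at the first error.
func splitURLs(s string) ([]*url.URL, error) {
	var urls []*url.URL
	for _, part := range strings.Fields(s) {
		u, err := url.Parse(part)
		if err != nil {
			return nil, err
		}
		urls = append(urls, u)
	}
	return urls, nil
}

func main() {
	// Multiple spaces and a tab are all treated as a single separator.
	urls, err := splitURLs("http://example.com   http://example.org\thttp://example.net")
	if err != nil {
		fmt.Println("error:", err)
		return
	}
	for _, u := range urls {
		fmt.Println(u.Host)
	}

	// An invalid percent-escape makes url.Parse return an error.
	if _, err := splitURLs("http://thisisinvalid%+"); err != nil {
		fmt.Println("error:", err)
	}
}
```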

We can use those together with flag.Func() like so:

package main

import (
	"flag"
	"log"
	"net/url"
	"strings"
)

func main() {
	// First we need to declare a variable to hold the values from the
	// command-line flags. You should also set any defaults, which will be used
	// if the relevant flag is not provided at runtime.
	var (
		urls []*url.URL // Default of a nil slice.
	)

	// The flag.Func() function takes three parameters: the flag name,
	// descriptive help text, and a function with the signature `func(string)
	// error` which is called to process the string value from the command-line
	// flag at runtime and assign it to the necessary variable. In this case, we
	// use strings.Fields() to split the string based on whitespace, parse each
	// part using url.Parse() and append it to the urls slice that we
	// declared above. If url.Parse() returns an error, then we return it from
	// the function.
	flag.Func("urls", "List of URLs to print", func(flagValue string) error {
		for _, u := range strings.Fields(flagValue) {
			parsedURL, err := url.Parse(u)
			if err != nil {
				return err
			}

			urls = append(urls, parsedURL)
		}
		return nil
	})

	// Importantly, call flag.Parse() to trigger actual parsing of the
	// flags.
	flag.Parse()

	// Print out each of the parsed URLs.
	for _, u := range urls {
		log.Print(u)
	}
}

If you try to run this application, you should find that the flags are parsed and work just like you would expect. For example:

$ go run . --urls="http://example.com http://example.org http://example.net"
2023/11/13 19:18:16 http://example.com
2023/11/13 19:18:16 http://example.org
2023/11/13 19:18:16 http://example.net

If instead you provide an invalid flag value that causes your function to return an error, the flag package will automatically print the error message and usage text, then exit with status 2. For example:

$ go run . --urls="http://example.com http://thisisinvalid%+"
invalid value "http://example.com http://thisisinvalid%+" for flag -urls: parse "http://thisisinvalid%+": invalid URL escape "%+"
Usage of /tmp/go-build520367363/b001/exe/cust:
  -urls value
        List of URLs to print
exit status 2

Default flag values

It's really important to point out that if a flag isn't provided, the corresponding flag.Func() function will not be called at all. This means that you cannot set a default value inside a flag.Func() function, so trying to do something like this won't work:

var foo string

flag.Func("example", "Example help text", func(flagValue string) error {
    // DON'T DO THIS! This function won't be called at all if the -example
    // flag isn't provided, so foo will never be set to "bar".
    if flagValue == "" {
        foo = "bar"
        return nil
    }

    ...
})

Instead you need to set the default value for a flag before flag.Func() is called. For example:

foo := "bar"

flag.Func("example", "Example help text", func(flagValue string) error {
    ...
})
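To check that behavior, here's a small runnable sketch that uses a separate flag.FlagSet so the same flag can be parsed with different arguments. The parseExample() function and the -example flag are hypothetical.

```go
package main

import (
	"flag"
	"fmt"
)

// parseExample is a hypothetical helper: it declares foo with a default of
// "bar", registers an -example flag with Func, and parses args.
func parseExample(args []string) (string, error) {
	fs := flag.NewFlagSet("defaults-example", flag.ContinueOnError)

	// The default is set here, before parsing. If -example is absent, the
	// function passed to Func is never called and foo keeps this value.
	foo := "bar"

	fs.Func("example", "Example help text", func(flagValue string) error {
		foo = flagValue
		return nil
	})

	if err := fs.Parse(args); err != nil {
		return "", err
	}
	return foo, nil
}

func main() {
	withoutFlag, _ := parseExample(nil)
	withFlag, _ := parseExample([]string{"-example=baz"})
	fmt.Println(withoutFlag) // prints "bar"
	fmt.Println(withFlag)    // prints "baz"
}
```

Note that explicitly passing -example= (with nothing after the equals sign) does call the function, with an empty string as the value.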

Validating flag values

The flag.Func() function also opens up some new opportunities for validating input data from command-line flags. For example, let's say that your application has an --environment flag and you want to restrict the possible values to development, staging or production.

To do that, you can implement a flag.Func() function similar to this:

package main

import (
    "errors"
    "flag"
    "fmt"
)

func main() {
    var (
        environment = "development"
    )

    flag.Func("environment", "Operating environment", func(flagValue string) error {
        for _, allowedValue := range []string{"development", "staging", "production"} {
            if flagValue == allowedValue {
                environment = flagValue
                return nil
            }
        }
        return errors.New(`must be one of "development", "staging" or "production"`)
    })

    flag.Parse()

    fmt.Printf("The operating environment is: %s\n", environment)
}

Making reusable helpers

If you find yourself repeating the same code in your flag.Func() functions, or the logic is getting too complex, you can break it out into a reusable helper. For example, we could rewrite the example above to process our --environment flag via a reusable enumFlag() function, like so:

package main

import (
    "flag"
    "fmt"
)

func main() {
    var (
        environment = "development"
    )

    enumFlag(&environment, "environment", []string{"development", "staging", "production"}, "Operating environment")

    flag.Parse()

    fmt.Printf("The operating environment is: %s\n", environment)
}

func enumFlag(target *string, name string, safelist []string, usage string) {
    flag.Func(name, usage, func(flagValue string) error {
        for _, allowedValue := range safelist {
            if flagValue == allowedValue {
                *target = flagValue
                return nil
            }
        }

        return fmt.Errorf("must be one of %v", safelist)
    })
}