Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 28 additions & 0 deletions .github/workflows/go.yml
Original file line number Diff line number Diff line change
Expand Up @@ -26,3 +26,31 @@ jobs:

- name: Test
run: go test -v ./...

golden-tests:
runs-on: ubuntu-latest
needs: build
steps:
- uses: actions/checkout@v4

- name: Set up Go
uses: actions/setup-go@v4
with:
go-version: '1.22'

- name: Run Golden Tests
run: go test -v -run TestGolden

end-to-end-tests:
runs-on: ubuntu-latest
needs: build
steps:
- uses: actions/checkout@v4

- name: Set up Go
uses: actions/setup-go@v4
with:
go-version: '1.22'

- name: Run End-to-End Tests
run: go test -v -run TestEndToEnd
29 changes: 28 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,8 @@ Flags:
transform each item name by removing a prefix or comma separated list of prefixes. Default: ""
-type string
comma-separated list of type names; must be set
-typederrors
if true, errors from enumerrs/ will be errors.Join()-ed for errors.Is(...) to simplify invalid value handling. Default: false
-values
if true, alternative string values method will be generated. Default: false
-yaml
Expand Down Expand Up @@ -70,6 +72,9 @@ When Enumer is applied to a type, it will generate:
the enum conform to the `gopkg.in/yaml.v2.Marshaler` and `gopkg.in/yaml.v2.Unmarshaler` interfaces.
- When the flag `sql` is provided, the methods for implementing the `Scanner` and `Valuer` interfaces.
Useful when storing the enum in a database.
- When the flag `typederrors` is provided, the string conversion functions will return errors wrapped with
`errors.Join()` containing a typed error from the `enumerrs` package. This allows you to use `errors.Is()` to
check for specific enum validation failures.


For example, if we have an enum type called `Pill`,
Expand Down Expand Up @@ -200,7 +205,7 @@ For a module-aware repo with `enumer` in the `go.mod` file, generation can be ca
//go:generate go run github.com/dmarkham/enumer -type=YOURTYPE
```

There are four boolean flags: `json`, `text`, `yaml` and `sql`. You can use any combination of them (i.e. `enumer -type=Pill -json -text`),
There are five boolean flags: `json`, `text`, `yaml`, `sql`, and `typederrors`. You can use any combination of them (i.e. `enumer -type=Pill -json -text -typederrors`),

For enum string representation transformation the `transform` and `trimprefix` flags
were added (i.e. `enumer -type=MyType -json -transform=snake`).
Expand All @@ -215,6 +220,28 @@ If a prefix is provided via the `addprefix` flag, it will be added to the start

The boolean flag `values` will additionally create an alternative string values method `Values() []string` to fullfill the `EnumValues` interface of [ent](https://entgo.io/docs/schema-fields/#enum-fields).

## Typed Error Handling

When using the `typederrors` flag, you can handle enum validation errors specifically using `errors.Is()`:

```go
import (
"errors"
"github.com/dmarkham/enumer/enumerrs"
)

// This will return a typed error that can be checked
pill, err := PillString("InvalidValue")
if err != nil {
if errors.Is(err, enumerrs.ErrValueInvalid) {
// Handle invalid enum value specifically
fmt.Println("Invalid pill value provided")
}
// The error also contains a descriptive message
fmt.Printf("Error: %v\n", err)
}
```

## Inspiring projects

- [Álvaro López Espinosa](https://github.com/alvaroloes/enumer)
Expand Down
17 changes: 14 additions & 3 deletions endtoend_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

// go command is not available on android

//go:build !android
// +build !android

package main
Expand Down Expand Up @@ -75,6 +76,7 @@ func TestEndToEnd(t *testing.T) {
// Names are known to be ASCII and long enough.
var typeName string
var transformNameMethod string
var useTypedErrors bool

switch name {
case "transform_snake.go":
Expand Down Expand Up @@ -110,18 +112,22 @@ func TestEndToEnd(t *testing.T) {
case "transform_whitespace.go":
typeName = "WhitespaceSeparatedValue"
transformNameMethod = "whitespace"
case "typedErrors.go":
typeName = "TypedErrorsValue"
transformNameMethod = "noop"
useTypedErrors = true
default:
typeName = fmt.Sprintf("%c%s", name[0]+'A'-'a', name[1:len(name)-len(".go")])
transformNameMethod = "noop"
}

stringerCompileAndRun(t, dir, stringer, typeName, name, transformNameMethod)
stringerCompileAndRun(t, dir, stringer, typeName, name, transformNameMethod, useTypedErrors)
}
}

// stringerCompileAndRun runs stringer for the named file and compiles and
// runs the target binary in directory dir. That binary will panic if the String method is incorrect.
func stringerCompileAndRun(t *testing.T, dir, stringer, typeName, fileName, transformNameMethod string) {
func stringerCompileAndRun(t *testing.T, dir, stringer, typeName, fileName, transformNameMethod string, useTypedErrors bool) {
t.Logf("run: %s %s\n", fileName, typeName)
source := filepath.Join(dir, fileName)
err := copy(source, filepath.Join("testdata", fileName))
Expand All @@ -130,7 +136,12 @@ func stringerCompileAndRun(t *testing.T, dir, stringer, typeName, fileName, tran
}
stringSource := filepath.Join(dir, typeName+"_string.go")
// Run stringer in temporary directory.
err = run(stringer, "-type", typeName, "-output", stringSource, "-transform", transformNameMethod, source)
args := []string{"-type", typeName, "-output", stringSource, "-transform", transformNameMethod}
if useTypedErrors {
args = append(args, "-typederrors", "-values")
}
args = append(args, source)
err = run(stringer, args...)
if err != nil {
t.Fatal(err)
}
Expand Down
51 changes: 27 additions & 24 deletions enumer.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,7 @@ package main

import "fmt"

// Arguments to format are:
// [1]: type name
// Arguments to format are: [1]: type name [2]: complete error expression
const stringNameToValueMethod = `// %[1]sString retrieves an enum value from the enum constants string name.
// Throws an error if the param is not part of the enum.
func %[1]sString(s string) (%[1]s, error) {
Expand All @@ -14,20 +13,18 @@ func %[1]sString(s string) (%[1]s, error) {
if val, ok := _%[1]sNameToValueMap[strings.ToLower(s)]; ok {
return val, nil
}
return 0, fmt.Errorf("%%s does not belong to %[1]s values", s)
return 0, %[2]s
}
`

// Arguments to format are:
// [1]: type name
// Arguments to format are: [1]: type name
const stringValuesMethod = `// %[1]sValues returns all values of the enum
func %[1]sValues() []%[1]s {
return _%[1]sValues
}
`

// Arguments to format are:
// [1]: type name
// Arguments to format are: [1]: type name
const stringsMethod = `// %[1]sStrings returns a slice of all String values of the enum
func %[1]sStrings() []string {
strs := make([]string, len(_%[1]sNames))
Expand All @@ -36,8 +33,7 @@ func %[1]sStrings() []string {
}
`

// Arguments to format are:
// [1]: type name
// Arguments to format are: [1]: type name
const stringBelongsMethodLoop = `// IsA%[1]s returns "true" if the value is listed in the enum definition. "false" otherwise
func (i %[1]s) IsA%[1]s() bool {
for _, v := range _%[1]sValues {
Expand All @@ -49,17 +45,15 @@ func (i %[1]s) IsA%[1]s() bool {
}
`

// Arguments to format are:
// [1]: type name
// Arguments to format are: [1]: type name
const stringBelongsMethodSet = `// IsA%[1]s returns "true" if the value is listed in the enum definition. "false" otherwise
func (i %[1]s) IsA%[1]s() bool {
_, ok := _%[1]sMap[i]
return ok
}
`

// Arguments to format are:
// [1]: type name
// Arguments to format are: [1]: type name
const altStringValuesMethod = `func (%[1]s) Values() []string {
return %[1]sStrings()
}
Expand All @@ -70,7 +64,7 @@ func (g *Generator) buildAltStringValuesMethod(typeName string) {
g.Printf(altStringValuesMethod, typeName)
}

func (g *Generator) buildBasicExtras(runs [][]Value, typeName string, runsThreshold int) {
func (g *Generator) buildBasicExtras(runs [][]Value, typeName string, runsThreshold int, useTypedErrors bool) {
// At this moment, either "g.declareIndexAndNameVars()" or "g.declareNameVars()" has been called

// Print the slice of values
Expand All @@ -89,7 +83,13 @@ func (g *Generator) buildBasicExtras(runs [][]Value, typeName string, runsThresh
g.printNamesSlice(runs, typeName, runsThreshold)

// Print the basic extra methods
g.Printf(stringNameToValueMethod, typeName)
var errorCode string
if useTypedErrors {
errorCode = fmt.Sprintf(`errors.Join(enumerrs.ErrValueInvalid, fmt.Errorf("%%s does not belong to %s values", s))`, typeName)
} else {
errorCode = fmt.Sprintf(`fmt.Errorf("%%s does not belong to %s values", s)`, typeName)
}
g.Printf(stringNameToValueMethod, typeName, errorCode)
g.Printf(stringValuesMethod, typeName)
g.Printf(stringsMethod, typeName)
if len(runs) <= runsThreshold {
Expand Down Expand Up @@ -143,8 +143,7 @@ func (g *Generator) printNamesSlice(runs [][]Value, typeName string, runsThresho
g.Printf("}\n\n")
}

// Arguments to format are:
// [1]: type name
// Arguments to format are: [1]: type name
const jsonMethods = `
// MarshalJSON implements the json.Marshaler interface for %[1]s
func (i %[1]s) MarshalJSON() ([]byte, error) {
Expand All @@ -164,12 +163,13 @@ func (i *%[1]s) UnmarshalJSON(data []byte) error {
}
`

func (g *Generator) buildJSONMethods(runs [][]Value, typeName string, runsThreshold int) {
func (g *Generator) buildJSONMethods(runs [][]Value, typeName string, runsThreshold int, useTypedErrors bool) {
// For now, just use the standard template
// We rely on the %[1]sString method to provide typed errors when enabled
g.Printf(jsonMethods, typeName)
}

// Arguments to format are:
// [1]: type name
// Arguments to format are: [1]: type name
const textMethods = `
// MarshalText implements the encoding.TextMarshaler interface for %[1]s
func (i %[1]s) MarshalText() ([]byte, error) {
Expand All @@ -184,12 +184,13 @@ func (i *%[1]s) UnmarshalText(text []byte) error {
}
`

func (g *Generator) buildTextMethods(runs [][]Value, typeName string, runsThreshold int) {
func (g *Generator) buildTextMethods(runs [][]Value, typeName string, runsThreshold int, useTypedErrors bool) {
// For now, just use the standard template
// We rely on the %[1]sString method to provide typed errors when enabled
g.Printf(textMethods, typeName)
}

// Arguments to format are:
// [1]: type name
// Arguments to format are: [1]: type name
const yamlMethods = `
// MarshalYAML implements a YAML Marshaler for %[1]s
func (i %[1]s) MarshalYAML() (interface{}, error) {
Expand All @@ -209,6 +210,8 @@ func (i *%[1]s) UnmarshalYAML(unmarshal func(interface{}) error) error {
}
`

func (g *Generator) buildYAMLMethods(runs [][]Value, typeName string, runsThreshold int) {
func (g *Generator) buildYAMLMethods(runs [][]Value, typeName string, runsThreshold int, useTypedErrors bool) {
// For now, just use the standard template
// We rely on the %[1]sString method to provide typed errors when enabled
g.Printf(yamlMethods, typeName)
}
8 changes: 8 additions & 0 deletions enumerrs/errors.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
package enumerrs

import "errors"

// This package defines custom error types for use in the generated code.

// ErrValueInvalid is returned when a value does not belong to the set of valid values for a type.
var ErrValueInvalid = errors.New("the input value is not valid for the type")
Loading