Skip to content

Instantly share code, notes, and snippets.

@choonkeat
Last active November 28, 2023 14:15
Show Gist options
  • Save choonkeat/74002057e3c74ebc3b4428b0161a80a7 to your computer and use it in GitHub Desktop.
Save choonkeat/74002057e3c74ebc3b4428b0161a80a7 to your computer and use it in GitHub Desktop.
Discriminated union for Go
package gosumtype
import (
"time"
)
// User is a sum type (discriminated union) equivalent to this
// ML/Elm-style declaration:
//
//	type User
//	    = Anonymous
//	    | Member String Time
//	    | Admin String
//
// Go has no native sum types, so callers "pattern match" by handing
// Switch one handler per variant. Ideally only this interface and
// UserScenarios would be written by hand; the per-variant boilerplate
// could be generated.
type User interface {
	// Switch is expected to call the single handler in scenarios that
	// corresponds to the concrete variant.
	Switch(scenarios UserScenarios)
}
// UserScenarios holds one handler per User variant. It is the Go
// stand-in for the arms of a pattern match. The field order is part of
// the API for callers that use an unkeyed composite literal.
type UserScenarios struct {
	// Anonymous handles the Anonymous variant, which carries no data.
	Anonymous func()
	// Member handles the Member variant and receives its email and the
	// time the membership started.
	Member func(email string, since time.Time)
	// Admin handles the Admin variant and receives its email.
	Admin func(email string)
}
package gosumtype
import (
"log"
"time"
)
// Caller is an example: it builds one value of each User variant and
// logs the rendering of each with UserString.
func Caller() {
	anon := Anonymous()
	alice := Member("Alice", time.Now())
	bob := Admin("Bob")

	// Rendered left to right, matching the order of the log arguments.
	anonText, aliceText, bobText := UserString(anon), UserString(alice), UserString(bob)

	log.Println(
		"User1:", anonText,
		"User2:", aliceText,
		"User3:", bobText,
	)
}
// UserString renders u as a human-readable string by matching on its
// variant:
//
//   - Anonymous -> "anonymous coward"
//   - Member    -> "<email> (member since <time>)"
//   - Admin     -> "<email> (admin)"
//
// The Member time is formatted with time.Time.String after stripping any
// monotonic clock reading via Round(0). Otherwise, values obtained from
// time.Now() would end in a debugging-only "m=+0.000…" suffix.
//
// u must be non-nil: calling Switch on a nil interface panics.
func UserString(u User) string {
	var result string
	u.Switch(UserScenarios{
		Anonymous: func() {
			result = "anonymous coward"
		},
		Member: func(email string, since time.Time) {
			result = email + " (member since " + since.Round(0).String() + ")"
		},
		Admin: func(email string) {
			result = email + " (admin)"
		},
	})
	return result
}
package gosumtype
import (
"log"
"strconv"
"testing"
"time"
)
func TestUser(t *testing.T) {
testCases := []struct {
givenUser User
}{
{
givenUser: Anonymous(),
},
{
givenUser: Member("[email protected]", time.Now()),
},
{
givenUser: Admin("[email protected]"),
},
}
for i, tc := range testCases {
t.Run(strconv.Itoa(i), func(t *testing.T) {
// using names are very helpful, but loses the exhaustive check at compile time
// since Go happily set the undefined scenarios as function zero value: nil
//
// but we can use https://golangci-lint.run/usage/linters/#exhaustruct
// to check at CI instead of suffering from zero value at runtime
tc.givenUser.Switch(UserScenarios{
Anonymous: func() {
log.Println("i am anonymous")
},
Member: func(email string, since time.Time) {
log.Println("member", email, since)
},
Admin: func(email string) {
log.Println("admin", email)
},
})
})
}
}
package gosumtype
import "time"
//
// Boiler plate code below:
//
// anonymous is the payload-free Anonymous variant of User.
type anonymous struct{}

// Switch dispatches to the Anonymous handler.
func (anonymous) Switch(s UserScenarios) {
	s.Anonymous()
}

// Anonymous constructs the Anonymous variant of User.
func Anonymous() User {
	return anonymous{}
}
// member is the Member variant of User: an email plus the time the
// membership started.
type member struct {
	email string
	since time.Time
}

// Switch dispatches to the Member handler with the stored fields.
func (mem member) Switch(s UserScenarios) {
	s.Member(mem.email, mem.since)
}

// Member constructs the Member variant of User.
func Member(email string, since time.Time) User {
	return member{email: email, since: since}
}
// admin is the Admin variant of User, carrying an email.
type admin struct {
	email string
}

// Switch dispatches to the Admin handler with the stored email.
func (adm admin) Switch(s UserScenarios) {
	s.Admin(adm.email)
}

// Admin constructs the Admin variant of User.
func Admin(email string) User {
	return admin{email: email}
}
@elfgoh
Copy link

elfgoh commented Aug 25, 2019

I'm not sure it meets your requirements, but I used this validator library in the example below: https://github.com/go-playground/validator

package gosumtype

import (
	"fmt"
	"log"
	"strconv"
	"testing"
	"time"

	"github.com/stretchr/testify/require"
	"gopkg.in/go-playground/validator.v9"
)

// User is a sum type (discriminated union) equivalent to this
// ML/Elm-style declaration:
//
//	type User
//	    = Anonymous
//	    | Member String Time
//	    | Admin String
//
// Callers match on a User by passing Switch one handler per variant.
// Ideally only this interface and UserScenarios are written by hand, and
// the per-variant boilerplate is generated.
type User interface {
	// Switch is expected to call the single handler in scenarios that
	// corresponds to the concrete variant.
	Switch(scenarios UserScenarios)
}

// UserScenarios holds one handler per User variant. The field order is
// part of the API for callers that use an unkeyed composite literal. No
// field carries a `validate` tag.
type UserScenarios struct {
	// Anonymous handles the Anonymous variant, which carries no data.
	Anonymous func()
	// Member handles the Member variant: its email and start time.
	Member func(email string, since time.Time)
	// Admin handles the Admin variant: its email.
	Admin func(email string)
}

// New builds a UserScenarios from one handler per variant.
//
// Every handler is a required positional argument, so leaving one out is
// a compile error. A keyed UserScenarios literal, by contrast, silently
// leaves an omitted field nil. Note that an explicit nil argument is still
// accepted and stored as-is.
func New(anonymous func(), member func(string, time.Time), admin func(string)) UserScenarios {
	var s UserScenarios
	s.Anonymous = anonymous
	s.Member = member
	s.Admin = admin
	return s
}

// PrivilegedUser is a second sum type whose variants are matched through
// UserScenariosMissingField. In this file only the Root variant (type
// root) implements it.
type PrivilegedUser interface {
	// Switch is expected to call the handler in scenarios that
	// corresponds to the concrete variant.
	Switch(scenarios UserScenariosMissingField)
}

// UserScenariosMissingField holds one handler per PrivilegedUser variant.
// Both fields are tagged `validate:"required"`, so go-playground/validator
// can report at runtime any handler that was left nil.
type UserScenariosMissingField struct {
	// God handles the God variant: an email and a start time.
	God func(email string, since time.Time) `validate:"required"`
	// Root handles the Root variant: an email.
	Root func(email string) `validate:"required"`
}

// NewMissingField returns a UserScenariosMissingField whose God handler is
// always nil. The god argument is accepted but never stored. Judging by the
// name and the commented-out assignment, this is intentional, so the
// validator demo has a `required` field to report.
func NewMissingField(god func(string, time.Time), root func(string)) UserScenariosMissingField {
	// god is intentionally discarded (was: God: god).
	_ = god
	return UserScenariosMissingField{Root: root}
}

// TestFunction matches every User variant with a UserScenarios value and
// asserts that Switch invokes exactly the handler for that variant.
//
// The handlers also log what they receive; run `go test -v` to see that
// output for passing tests.
func TestFunction(t *testing.T) {
	testCases := []struct {
		givenUser User
		want      string // name of the handler Switch should invoke
	}{
		{givenUser: Anonymous(), want: "anonymous"},
		{givenUser: Member("[email protected]", time.Now()), want: "member"},
		{givenUser: Admin("[email protected]"), want: "admin"},
	}

	for i, tc := range testCases {
		t.Run(strconv.Itoa(i), func(t *testing.T) {
			var got []string

			// An unnamed (positional) struct literal keeps the exhaustive
			// compile-time check, but the scenarios are hard to tell apart,
			// so readability suffers:
			//
			// tc.givenUser.Switch(UserScenarios{
			// 	func() {
			// 		log.Println("i am anonymous")
			// 	},
			// 	func(email string, since time.Time) {
			// 		log.Println("member", email, since)
			// 	},
			// 	func(email string) {
			// 		log.Println("admin", email)
			// 	},
			// })

			// Using field names is much more readable, but it loses the
			// exhaustive compile-time check: Go silently sets any omitted
			// scenario to the func zero value, nil.
			tc.givenUser.Switch(UserScenarios{
				Anonymous: func() {
					got = append(got, "anonymous")
					log.Println("i am anonymous")
				},
				Member: func(email string, since time.Time) {
					got = append(got, "member")
					log.Println("member", email, since)
				},
				Admin: func(email string) {
					got = append(got, "admin")
					log.Println("admin", email)
				},
			})

			// Exactly one handler, the matching one, must have run.
			require.Equal(t, []string{tc.want}, got)
		})
	}
}

// TestNewStructFieldSuccess builds a UserScenarios with New, checks that it
// passes validation, and asserts that Switch then runs exactly the
// matching handler.
//
// UserScenarios has no `validate` tags, so validation is expected to pass
// trivially. Any reported field errors are printed before the test fails.
func TestNewStructFieldSuccess(t *testing.T) {
	testCases := []struct {
		givenUser User
		want      string // name of the handler Switch should invoke
	}{
		{givenUser: Anonymous(), want: "anonymous"},
		{givenUser: Member("[email protected]", time.Now()), want: "member"},
		{givenUser: Admin("[email protected]"), want: "admin"},
	}

	for i, tc := range testCases {
		t.Run(strconv.Itoa(i), func(t *testing.T) {
			var got []string

			anonymous := func() {
				got = append(got, "anonymous")
				log.Println("i am anonymous")
			}
			member := func(email string, since time.Time) {
				got = append(got, "member")
				log.Println("member", email, since)
			}
			admin := func(email string) {
				got = append(got, "admin")
				log.Println("admin", email)
			}

			userScenarios := New(anonymous, member, admin)

			validate := validator.New()
			err := validate.Struct(userScenarios)
			if verrs, ok := err.(validator.ValidationErrors); ok {
				for _, fe := range verrs {
					fmt.Println("The following field is missing: ", fe.Field())
				}
			}
			require.NoError(t, err)

			tc.givenUser.Switch(userScenarios)

			require.Equal(t, []string{tc.want}, got)
		})
	}
}

// TestNewStructFieldMissing shows that the validator catches, at runtime,
// the nil God handler left by NewMissingField. It also shows that a Root
// value still dispatches to its (present) Root handler.
func TestNewStructFieldMissing(t *testing.T) {
	testCases := []struct {
		givenUser PrivilegedUser
	}{
		// No God variant is implemented in this file, so only Root is
		// exercised.
		// {
		// 	givenUser: God("[email protected]", time.Now()),
		// },
		{givenUser: Root("[email protected]")},
	}

	for i, tc := range testCases {
		t.Run(strconv.Itoa(i), func(t *testing.T) {
			var got []string

			god := func(email string, since time.Time) {
				got = append(got, "god")
				log.Println("god", email, since)
			}
			root := func(email string) {
				got = append(got, "root")
				log.Println("root", email)
			}
			userScenarios := NewMissingField(god, root)

			validate := validator.New()
			err := validate.Struct(userScenarios)
			require.Error(t, err, "validator should reject the nil God handler")

			verrs, ok := err.(validator.ValidationErrors)
			require.True(t, ok, "expected validator.ValidationErrors, got %T: %v", err, err)

			missing := make([]string, 0, len(verrs))
			for _, fe := range verrs {
				fmt.Println("The following field is missing: ", fe.Field())
				missing = append(missing, fe.Field())
			}
			require.Equal(t, []string{"God"}, missing)

			// Root's handler is set, so dispatching a Root value is safe.
			tc.givenUser.Switch(userScenarios)
			require.Equal(t, []string{"root"}, got)
		})
	}
}

//
// Boiler plate code below:
//

// anonymous is the payload-free Anonymous variant of User.
type anonymous struct{}

// Switch dispatches to the Anonymous handler.
func (anonymous) Switch(s UserScenarios) {
	s.Anonymous()
}

// Anonymous constructs the Anonymous variant of User.
func Anonymous() User {
	return anonymous{}
}

// member is the Member variant of User: an email plus the time the
// membership started.
type member struct {
	email string
	since time.Time
}

// Switch dispatches to the Member handler with the stored fields.
func (mem member) Switch(s UserScenarios) {
	s.Member(mem.email, mem.since)
}

// Member constructs the Member variant of User.
func Member(email string, since time.Time) User {
	return member{email: email, since: since}
}

// admin is the Admin variant of User, carrying an email.
type admin struct {
	email string
}

// Switch dispatches to the Admin handler with the stored email.
func (adm admin) Switch(s UserScenarios) {
	s.Admin(adm.email)
}

// Admin constructs the Admin variant of User.
func Admin(email string) User {
	return admin{email: email}
}

// root is the Root variant of PrivilegedUser, carrying an email.
type root struct {
	email string
}

// Switch dispatches to the Root handler with the stored email.
func (rt root) Switch(s UserScenariosMissingField) {
	s.Root(rt.email)
}

// Root constructs the Root variant of PrivilegedUser.
func Root(email string) PrivilegedUser {
	return root{email: email}
}

@choonkeat
Copy link
Author

Not really, but it does suggest a possible approach. The downside is that it is a runtime check, not a compile-time check.

diff --git a/gosumtype_test.go b/gosumtype_test.go
index ac70d67..2738c36 100644
--- a/gosumtype_test.go
+++ b/gosumtype_test.go
@@ -2,6 +2,7 @@ package gosumtype
 
 import (
 	"log"
+	"reflect"
 	"strconv"
 	"testing"
 	"time"
@@ -59,7 +60,7 @@ func TestFunction(t *testing.T) {
 				func(email string) {
 					log.Println("admin", email)
 				},
-			})
+			}.Exhaustive())
 
 			// using names are very helpful, but loses the exhaustive check at compile time
 			// since Go happily set the undefined scenarios as function zero value: nil
@@ -73,7 +74,7 @@ func TestFunction(t *testing.T) {
 				Admin: func(email string) {
 					log.Println("admin", email)
 				},
-			})
+			}.Exhaustive())
 		})
 	}
 }
@@ -82,6 +83,18 @@ func TestFunction(t *testing.T) {
 // Boiler plate code below:
 //
 
+// Runtime exhaustive check
+func (s UserScenarios) Exhaustive() UserScenarios {
+	valueOf := reflect.ValueOf(&s).Elem()
+	typeOf := valueOf.Type()
+	for i := 0; i < valueOf.NumField(); i++ {
+		if valueOf.Field(i).IsNil() {
+			panic(typeOf.Field(i).Name + " is not covered")
+		}
+	}
+	return s
+}
+
 // Anonymous
 type anonymous struct{}
 

@choonkeat-govtech
Copy link

regarding

using names are very helpful, but loses the exhaustive check at compile time

actually we have https://golangci-lint.run/usage/linters/#exhaustruct now 🌈

@yongchongye
Copy link

@choonkeat
Copy link
Author

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment