Merge pull request #61 from dnephin/assert-error-type

Add cmp.ErrorType
diff --git a/assert/assert.go b/assert/assert.go
index 82760a4..1896618 100644
--- a/assert/assert.go
+++ b/assert/assert.go
@@ -34,7 +34,7 @@
 	    assert.NilError(t, closer.Close())
 	    assert.Assert(t, is.Error(err, "the exact error message"))
 	    assert.Assert(t, is.ErrorContains(err, "includes this"))
-	    assert.Assert(t, os.IsNotExist(err), "got %+v", err)
+	    assert.Assert(t, is.ErrorType(err, os.IsNotExist))
 
 	    // complex types
 	    assert.DeepEqual(t, result, myStruct{Name: "title"})
diff --git a/assert/cmp/compare.go b/assert/cmp/compare.go
index 3072c14..64a25f0 100644
--- a/assert/cmp/compare.go
+++ b/assert/cmp/compare.go
@@ -237,3 +237,92 @@
 		return ResultFailure(fmt.Sprintf("%v (type %s) can not be nil", value, value.Type()))
 	}
 }
+
+// ErrorType succeeds if err is not nil and is of the expected type.
+//
+// Expected can be one of:
+// a func(error) bool which returns true if the error is the expected type,
+// an instance of a struct, or a pointer to a struct, of the expected type,
+// a pointer to an interface the error is expected to implement,
+// a reflect.Type of the expected struct, pointer, or interface.
+//
+// Struct and pointer types must match the dynamic type of err exactly;
+// interface types only require that the dynamic type of err implements them.
+func ErrorType(err error, expected interface{}) Comparison {
+	return func() Result {
+		switch expectedType := expected.(type) {
+		case func(error) bool:
+			return cmpErrorTypeFunc(err, expectedType)
+		case reflect.Type:
+			if expectedType.Kind() == reflect.Interface {
+				return cmpErrorTypeImplementsType(err, expectedType)
+			}
+			return cmpErrorTypeEqualType(err, expectedType)
+		case nil:
+			return ResultFailure("invalid type for expected: nil")
+		}
+
+		expectedType := reflect.TypeOf(expected)
+		switch {
+		case expectedType.Kind() == reflect.Struct, isPtrToStruct(expectedType):
+			// Many errors are returned as pointers (e.g. *os.PathError), so a
+			// pointer instance is matched by exact type, just like a struct value.
+			return cmpErrorTypeEqualType(err, expectedType)
+		case isPtrToInterface(expectedType):
+			// Interface types cannot be passed as values, so callers pass a nil
+			// pointer to one, e.g. (*MyIface)(nil); match against the interface itself.
+			return cmpErrorTypeImplementsType(err, expectedType.Elem())
+		}
+		return ResultFailure(fmt.Sprintf("invalid type for expected: %T", expected))
+	}
+}
+
+// cmpErrorTypeFunc succeeds when the predicate f returns true for err.
+func cmpErrorTypeFunc(err error, f func(error) bool) Result {
+	if f(err) {
+		return ResultSuccess
+	}
+	actual := "nil"
+	if err != nil {
+		actual = fmt.Sprintf("%s (%T)", err, err)
+	}
+	// callArg 1 is the predicate argument of the ErrorType call. The "with"
+	// clause omits the ", not ..." suffix when that argument is unavailable.
+	return ResultFailureTemplate(`error is {{ .Data.actual }}
+		{{- with callArg 1 }}, not {{ formatNode . }}{{end -}}`,
+		map[string]interface{}{"actual": actual})
+}
+
+// cmpErrorTypeEqualType succeeds when the dynamic type of err is exactly
+// expectedType.
+func cmpErrorTypeEqualType(err error, expectedType reflect.Type) Result {
+	if err == nil {
+		return ResultFailure(fmt.Sprintf("error is nil, not %s", expectedType))
+	}
+	if reflect.TypeOf(err) == expectedType {
+		return ResultSuccess
+	}
+	return ResultFailure(fmt.Sprintf("error is %s (%T), not %s", err, err, expectedType))
+}
+
+// cmpErrorTypeImplementsType succeeds when the dynamic type of err implements
+// the interface type expectedType.
+func cmpErrorTypeImplementsType(err error, expectedType reflect.Type) Result {
+	if err == nil {
+		return ResultFailure(fmt.Sprintf("error is nil, not %s", expectedType))
+	}
+	if reflect.TypeOf(err).Implements(expectedType) {
+		return ResultSuccess
+	}
+	return ResultFailure(fmt.Sprintf("error is %s (%T), not %s", err, err, expectedType))
+}
+
+// isPtrToStruct reports whether typ is a pointer to a struct type.
+func isPtrToStruct(typ reflect.Type) bool {
+	return typ.Kind() == reflect.Ptr && typ.Elem().Kind() == reflect.Struct
+}
+
+// isPtrToInterface reports whether typ is a pointer to an interface type.
+func isPtrToInterface(typ reflect.Type) bool {
+	return typ.Kind() == reflect.Ptr && typ.Elem().Kind() == reflect.Interface
+}
diff --git a/assert/cmp/compare_test.go b/assert/cmp/compare_test.go
index 14a8f80..bf71858 100644
--- a/assert/cmp/compare_test.go
+++ b/assert/cmp/compare_test.go
@@ -348,3 +348,183 @@
 		t.Errorf("expected \n%q\ngot\n%q\n", expected, message)
 	}
 }
+
+// stubError is the error value most ErrorType tests use as the actual error.
+type stubError struct{}
+
+func (s stubError) Error() string {
+	return "stub error"
+}
+
+func isErrorOfTypeStub(err error) bool {
+	return reflect.TypeOf(err) == reflect.TypeOf(stubError{})
+}
+
+// notStubError is a second concrete error type, distinct from stubError.
+type notStubError struct{}
+
+func (s notStubError) Error() string {
+	return "not stub error"
+}
+
+func isErrorOfTypeNotStub(err error) bool {
+	return reflect.TypeOf(err) == reflect.TypeOf(notStubError{})
+}
+
+// ptrStubError implements error only through its pointer receiver, so
+// *ptrStubError is an error but ptrStubError is not (like *os.PathError).
+type ptrStubError struct{}
+
+func (p *ptrStubError) Error() string {
+	return "pointer stub error"
+}
+
+// specialStubIface is an interface none of the stub errors implement.
+type specialStubIface interface {
+	Special()
+}
+
+func TestErrorTypeWithNil(t *testing.T) {
+	var testcases = []struct {
+		name     string
+		expType  interface{}
+		expected string
+	}{
+		{
+			name:     "with struct",
+			expType:  stubError{},
+			expected: "error is nil, not cmp.stubError",
+		},
+		{
+			name:     "with interface",
+			expType:  (*specialStubIface)(nil),
+			expected: "error is nil, not cmp.specialStubIface",
+		},
+		{
+			name:     "with reflect.Type",
+			expType:  reflect.TypeOf(stubError{}),
+			expected: "error is nil, not cmp.stubError",
+		},
+		{
+			name:     "with reflect.Type interface",
+			expType:  reflect.TypeOf((*specialStubIface)(nil)).Elem(),
+			expected: "error is nil, not cmp.specialStubIface",
+		},
+		{
+			name:     "with reflect.Type pointer",
+			expType:  reflect.TypeOf(&ptrStubError{}),
+			expected: "error is nil, not *cmp.ptrStubError",
+		},
+	}
+	for _, testcase := range testcases {
+		t.Run(testcase.name, func(t *testing.T) {
+			result := ErrorType(nil, testcase.expType)()
+			assertFailure(t, result, testcase.expected)
+		})
+	}
+}
+
+func TestErrorTypeSuccess(t *testing.T) {
+	var testcases = []struct {
+		name    string
+		expType interface{}
+	}{
+		{
+			name:    "with function",
+			expType: isErrorOfTypeStub,
+		},
+		{
+			name:    "with struct",
+			expType: stubError{},
+		},
+		{
+			name:    "with interface",
+			expType: (*error)(nil),
+		},
+		{
+			name:    "with reflect.Type struct",
+			expType: reflect.TypeOf(stubError{}),
+		},
+		{
+			name:    "with reflect.Type interface",
+			expType: reflect.TypeOf((*error)(nil)).Elem(),
+		},
+	}
+	for _, testcase := range testcases {
+		t.Run(testcase.name, func(t *testing.T) {
+			result := ErrorType(stubError{}, testcase.expType)()
+			assertSuccess(t, result)
+		})
+	}
+}
+
+func TestErrorTypeFailure(t *testing.T) {
+	var testcases = []struct {
+		name     string
+		expType  interface{}
+		expected string
+	}{
+		{
+			name:     "with struct",
+			expType:  notStubError{},
+			expected: "error is stub error (cmp.stubError), not cmp.notStubError",
+		},
+		{
+			name:     "with interface",
+			expType:  (*specialStubIface)(nil),
+			expected: "error is stub error (cmp.stubError), not cmp.specialStubIface",
+		},
+		{
+			name:     "with reflect.Type struct",
+			expType:  reflect.TypeOf(notStubError{}),
+			expected: "error is stub error (cmp.stubError), not cmp.notStubError",
+		},
+		{
+			name:     "with reflect.Type interface",
+			expType:  reflect.TypeOf((*specialStubIface)(nil)).Elem(),
+			expected: "error is stub error (cmp.stubError), not cmp.specialStubIface",
+		},
+		{
+			name:     "with reflect.Type pointer",
+			expType:  reflect.TypeOf(&ptrStubError{}),
+			expected: "error is stub error (cmp.stubError), not *cmp.ptrStubError",
+		},
+	}
+	for _, testcase := range testcases {
+		t.Run(testcase.name, func(t *testing.T) {
+			result := ErrorType(stubError{}, testcase.expType)()
+			assertFailure(t, result, testcase.expected)
+		})
+	}
+}
+
+// TestErrorTypeWithPointerError covers an actual error whose Error method has
+// a pointer receiver, so its dynamic type is *ptrStubError.
+func TestErrorTypeWithPointerError(t *testing.T) {
+	var err error = &ptrStubError{}
+
+	assertSuccess(t, ErrorType(err, reflect.TypeOf(&ptrStubError{}))())
+	assertSuccess(t, ErrorType(err, (*error)(nil))())
+	assertFailure(t, ErrorType(err, reflect.TypeOf(stubError{}))(),
+		"error is pointer stub error (*cmp.ptrStubError), not cmp.stubError")
+}
+
+func TestErrorTypeInvalid(t *testing.T) {
+	result := ErrorType(stubError{}, nil)()
+	assertFailure(t, result, "invalid type for expected: nil")
+
+	result = ErrorType(stubError{}, "my type!")()
+	assertFailure(t, result, "invalid type for expected: string")
+}
+
+func TestErrorTypeWithFunc(t *testing.T) {
+	result := ErrorType(nil, isErrorOfTypeStub)()
+	assertFailureTemplate(t, result,
+		[]ast.Expr{nil, &ast.Ident{Name: "isErrorOfTypeStub"}},
+		"error is nil, not isErrorOfTypeStub")
+
+	result = ErrorType(stubError{}, isErrorOfTypeNotStub)()
+	assertFailureTemplate(t, result,
+		[]ast.Expr{nil, &ast.Ident{Name: "isErrorOfTypeNotStub"}},
+		"error is stub error (cmp.stubError), not isErrorOfTypeNotStub")
+}