blob: 59eb24f36076a1acdb9fd8edbbb999916e9b0eea [file] [log] [blame]
// Copyright 2015 The LUCI Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package assertions
import (
"fmt"
"go.chromium.org/luci/common/errors"
"github.com/smarty/assertions"
)
// ShouldContainErr checks if an `errors.MultiError` on the left side contains
// as one of its errors an `error` or `string` on the right side. If nothing is
// provided on the right, checks that the left side contains at least one
// non-nil error. If nil is provided on the right, checks that the left side
// contains at least one nil, even if it contains other errors.
//
// Equivalent to calling ShouldErrLike on each `error` in an `errors.MultiError`
// and succeeding as long as one of the ShouldErrLike calls succeeds.
//
// To avoid confusion, explicitly rejects the special case where the right side
// is an `errors.MultiError`.
//
// Returns "" on success, or a human-readable failure message otherwise.
func ShouldContainErr(actual any, expected ...any) string {
	if len(expected) > 1 {
		return fmt.Sprintf("ShouldContainErr requires 0 or 1 expected value, got %d", len(expected))
	}
	if actual == nil {
		return assertions.ShouldNotBeNil(actual)
	}
	me, ok := actual.(errors.MultiError)
	if !ok {
		return assertions.ShouldHaveSameTypeAs(actual, errors.MultiError{})
	}
	if len(expected) == 0 {
		return assertions.ShouldNotBeNil(me.First())
	}

	want := expected[0]
	switch want.(type) {
	case errors.MultiError:
		// Checked before the `error` case on purpose: a MultiError presumably
		// also satisfies `error` and would otherwise be accepted below.
		return "expected value must not be a MultiError"
	case nil, string, error:
		// Supported expectation types; nil means "contains a nil entry".
	default:
		return fmt.Sprintf("unexpected argument type %T, expected string or error", want)
	}

	for _, err := range me {
		if ShouldErrLike(err, want) == "" {
			return ""
		}
	}
	if want == nil {
		// Formatting a nil interface with %q would yield "%!q(<nil>)".
		return "expected MultiError to contain a nil error"
	}
	return fmt.Sprintf("expected MultiError to contain %q", want)
}
// ShouldErrLike compares an `error` or `string` on the left side, to `error`s
// or `string`s on the right side.
//
// If multiple errors/strings are provided on the righthand side, they must all
// be contained in the stringified error on the lefthand side.
//
// If the righthand side is the singular `nil`, this expects the error to be
// nil.
//
// Example:
//
//	// Usage                              Equivalent To
//	So(err, ShouldErrLike, "custom")     // `err.Error()` ShouldContainSubstring "custom"
//	So(err, ShouldErrLike, io.EOF)       // `err.Error()` ShouldContainSubstring io.EOF.Error()
//	So(err, ShouldErrLike, "EOF")        // `err.Error()` ShouldContainSubstring "EOF"
//	So(err, ShouldErrLike,
//		"thing", "other", "etc.")        // `err.Error()` contains all of these substrings.
//	So(nilErr, ShouldErrLike, nil)       // nilErr ShouldBeNil
//	So(nonNilErr, ShouldErrLike, "")     // nonNilErr ShouldNotBeNil
//
// A plain string on the left side (for example a value recovered from
// `panic("...")`) is used directly in place of `err.Error()`. Any other
// non-error left side fails with an "implements error" message.
//
// Returns "" on success, or a human-readable failure message otherwise.
func ShouldErrLike(actual any, expected ...any) string {
	if len(expected) == 0 {
		return "ShouldErrLike requires 1 or more expected values, got 0"
	}
	// nil is only meaningful as the sole expectation ("actual must be nil"),
	// so reject it anywhere in a multi-value expectation list.
	if len(expected) > 1 {
		for _, e := range expected {
			if e == nil {
				return "ShouldErrLike only accepts `nil` on the right hand side as the sole argument."
			}
		}
	}
	if expected[0] == nil { // this can only happen if len(expected) == 1
		return assertions.ShouldBeNil(actual)
	}
	if actual == nil {
		return assertions.ShouldNotBeNil(actual)
	}

	// Resolve the text to search in, honoring the documented error-or-string
	// contract for the left side.
	var msg string
	switch a := actual.(type) {
	case error:
		msg = a.Error()
	case string:
		msg = a
	default:
		return assertions.ShouldImplement(actual, (*error)(nil))
	}

	for _, expect := range expected {
		var want string
		switch x := expect.(type) {
		case string:
			want = x
		case error:
			want = x.Error()
		default:
			return fmt.Sprintf("unexpected argument type %T, expected string or error", expect)
		}
		if ret := assertions.ShouldContainSubstring(msg, want); ret != "" {
			return ret
		}
	}
	return ""
}
// ShouldPanicLike is the same as ShouldErrLike, but with the exception that it
// takes a panic'ing func() as its first argument, instead of the error itself.
//
// The supplied function is invoked. The value it panics with (or nil, if it
// returns normally) is then checked against expected via ShouldErrLike.
// A non-func() first argument is reported as a type error. Returns "" on
// success, or a failure message otherwise.
func ShouldPanicLike(function any, expected ...any) (ret string) {
	fn, isFunc := function.(func())
	if !isFunc {
		return fmt.Sprintf("unexpected argument type %T, expected `func()`", function)
	}

	// This deferred closure runs on both the panic and the normal-return path
	// and always assigns the named result. When fn returns normally, recover()
	// yields nil, so the verdict becomes ShouldErrLike(nil, expected...).
	defer func() {
		panicValue := recover()
		ret = ShouldErrLike(panicValue, expected...)
	}()

	fn()
	return "" // always replaced by the deferred closure above
}
// ShouldUnwrapTo asserts that an error, when unwrapped, equals another error.
//
// The actual value must be an `error`. It is unwrapped with errors.Unwrap from
// the luci errors package, and the result is compared with assertions.ShouldEqual
// against the single `error` given in expected.
//
// Returns "" on success, or a human-readable failure message otherwise.
func ShouldUnwrapTo(actual any, expected ...any) string {
	act, ok := actual.(error)
	if !ok {
		// Report the dynamic type of actual: act is the zero value (nil) after
		// a failed comma-ok assertion, so %T of act would always be "<nil>".
		return fmt.Sprintf("ShouldUnwrapTo requires an error actual type, got %T", actual)
	}
	if len(expected) != 1 {
		return fmt.Sprintf("ShouldUnwrapTo requires exactly one expected value, got %d", len(expected))
	}
	exp, ok := expected[0].(error)
	if !ok {
		return fmt.Sprintf("ShouldUnwrapTo requires an error expected type, got %T", expected[0])
	}
	return assertions.ShouldEqual(errors.Unwrap(act), exp)
}