blob: 7021b328e4a840107dd266bb5f0e8a11e78fb771 [file] [log] [blame]
// Copyright 2021 The ChromiumOS Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
// Package xcontext provides context.Context implementations that can be
// canceled with custom errors.
package xcontext
import (
"context"
"errors"
"sync/atomic"
"time"
"code.cloudfoundry.org/clock"
)
// clk is the clock consulted for every "now" and timer in this package.
// Unit tests replace it with a fake clock.
var clk clock.Clock = clock.NewClock()
// keyType is an unexported type for context value keys defined by this
// package. Because the type is unexported, its keys never compare equal to
// keys created by other packages.
type keyType string

// contextTimeoutKey is the context value key under which WithTimeout records
// the timeout duration it was given.
const contextTimeoutKey = keyType("context_timeout_duration")
// CancelFunc cancels its associated context, making err the value reported by
// the context's Err method.
//
// Only the first cancellation takes effect: calling a CancelFunc on a context
// that is already canceled does nothing. Passing a nil err panics.
//
// When a CancelFunc returns, the associated context is guaranteed to be
// canceled: its Done channel is closed and its Err method returns non-nil.
type CancelFunc func(err error)
// contextImpl is a context.Context whose Err may report arbitrary errors
// supplied by the caller, rather than only context.Canceled or
// context.DeadlineExceeded.
type contextImpl struct {
	// parent is the context this one was derived from. Value lookups are
	// delegated to it.
	parent context.Context

	// done is the channel handed out by Done. It is closed once, when the
	// context becomes canceled.
	done chan struct{}
	// req carries cancellation errors to the background goroutine. It is
	// buffered with room for one element so that the first send never blocks.
	req chan error
	// errValue stores the error reported by Err. It is set before done is
	// closed.
	errValue atomic.Value

	// deadline is the time at which this context expires. It is meaningful
	// only when hasDeadline is true.
	deadline    time.Time
	hasDeadline bool
}
// newContext derives a cancelable child context from parent.
//
// When deadlineErr is nil, the child simply inherits parent's deadline and
// reqDeadline is ignored. When deadlineErr is non-nil, the child's deadline
// is whichever of reqDeadline and parent's deadline comes first. If the
// child's own deadline (reqDeadline) is the one that fires, Err reports
// deadlineErr.
//
// If the child is already dead at creation time, it is canceled synchronously
// and no goroutine is started. This happens when parent is already canceled,
// or when the child's own deadline has already passed. Otherwise a background
// goroutine waits for the first of three signals and cancels the child:
// parent cancellation, deadline expiry, or a CancelFunc call.
func newContext(parent context.Context, deadlineErr error, reqDeadline time.Time) (context.Context, CancelFunc) {
	deadline, hasDeadline := parent.Deadline()

	// ownTimer is true when this context installs its own, earlier deadline
	// and therefore must run its own timer. Otherwise it relies on the
	// parent's deadline, if any.
	ownTimer := deadlineErr != nil && (!hasDeadline || reqDeadline.Before(deadline))
	if ownTimer {
		deadline, hasDeadline = reqDeadline, true
	}

	c := &contextImpl{
		parent:      parent,
		hasDeadline: hasDeadline,
		deadline:    deadline,
		done:        make(chan struct{}),
		req:         make(chan error, 1),
	}

	// finish stores the error before closing done, so anyone who observes a
	// closed done channel also sees a non-nil Err.
	finish := func(err error) {
		c.errValue.Store(err)
		close(c.done)
	}

	var early error
	if perr := parent.Err(); perr != nil {
		early = perr
	} else if ownTimer && !deadline.After(clk.Now()) {
		early = deadlineErr
	}
	if early != nil {
		finish(early)
		return c, c.cancel
	}

	// awaitCancel blocks until the first cancellation signal arrives and
	// returns the error Err should report.
	awaitCancel := func() error {
		var expired <-chan time.Time // stays nil (never ready) without our own timer
		if ownTimer {
			timer := clk.NewTimer(deadline.Sub(clk.Now()))
			defer timer.Stop()
			expired = timer.C()
		}
		select {
		case <-parent.Done():
			return parent.Err()
		case <-expired:
			return deadlineErr
		case err := <-c.req:
			return err
		}
	}

	go func() { finish(awaitCancel()) }()
	return c, c.cancel
}
// Deadline reports when this context expires. ok is false when no deadline
// was set on this context or inherited from its parent.
func (c *contextImpl) Deadline() (deadline time.Time, ok bool) {
	deadline, ok = c.deadline, c.hasDeadline
	return deadline, ok
}
// Done returns a channel that is closed once the context is canceled, whether
// by its CancelFunc, by its deadline, or by its parent.
func (c *contextImpl) Done() <-chan struct{} {
	var ch <-chan struct{} = c.done
	return ch
}
// Err returns nil while the context is live, and the cancellation error once
// it has been canceled.
//
// Unlike the strict context.Context contract, the returned error is not
// limited to context.Canceled or context.DeadlineExceeded. It is whatever
// error caused the cancellation.
func (c *contextImpl) Err() error {
	// Load returns nil before any Store; the comma-ok assertion then yields
	// a nil error.
	err, _ := c.errValue.Load().(error)
	return err
}
// Value looks up key in the parent context. contextImpl carries no values of
// its own.
func (c *contextImpl) Value(key interface{}) interface{} {
	p := c.parent
	return p.Value(key)
}
// cancel implements CancelFunc for c. It asks for c to be canceled with err,
// then blocks until c is actually canceled, whether by this call or by an
// earlier signal. It panics if err is nil.
func (c *contextImpl) cancel(err error) {
	if err != nil {
		// Non-blocking send. req has room for one element, so the first
		// request is always delivered. Later ones are dropped, because the
		// context is already canceled or about to be.
		select {
		case c.req <- err:
		default:
		}
		<-c.done
		return
	}
	panic("xcontext: Cancel called with nil")
}
// WithCancel returns a child of parent that can be canceled with an arbitrary
// error through the returned CancelFunc. The child inherits parent's
// deadline.
func WithCancel(parent context.Context) (context.Context, CancelFunc) {
	var noDeadline time.Time // ignored by newContext when deadlineErr is nil
	return newContext(parent, nil, noDeadline)
}
// WithDeadline returns a child of parent that can be canceled with arbitrary
// errors. The child's deadline is the earlier of t and parent's deadline. If
// t is the one reached, the child's Err reports err. It panics if err is nil.
func WithDeadline(parent context.Context, t time.Time, err error) (context.Context, CancelFunc) {
	if err != nil {
		return newContext(parent, err, t)
	}
	panic("xcontext: WithDeadline called with nil err")
}
// WithTimeout is WithDeadline with a deadline of d from now, as measured by
// the package clock. When that deadline is reached, the context's Err reports
// err. It panics if err is nil.
//
// The timeout d is also attached to the returned context as a value, so that
// GetContextTimeout can retrieve it, for example when reporting the timeout.
func WithTimeout(parent context.Context, d time.Duration, err error) (context.Context, CancelFunc) {
	if err == nil {
		panic("xcontext: WithTimeout called with nil err")
	}
	deadline := clk.Now().Add(d)
	inner, cancel := WithDeadline(parent, deadline, err)
	return context.WithValue(inner, contextTimeoutKey, d), cancel
}
// GetContextTimeout returns the timeout duration that WithTimeout recorded on
// ctx, as found by ctx.Value.
//
// If no timeout was recorded, it returns 0 and a non-nil error. It never
// returns -1.
func GetContextTimeout(ctx context.Context) (time.Duration, error) {
	if t, ok := ctx.Value(contextTimeoutKey).(time.Duration); ok {
		return t, nil
	}
	return 0, errors.New("timeout not set on context")
}