blob: ba4da0b38b9365c32c100e21960aa6b24c3652a8 [file] [log] [blame]
// Copyright 2017 The LUCI Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Package tqtesting can be used in unit tests to simulate task queue calls
// produced by tq.Dispatcher.
package tqtesting
import (
"context"
"fmt"
"sort"
"strings"
"time"
"github.com/golang/protobuf/proto"
"go.chromium.org/gae/service/taskqueue"
"go.chromium.org/luci/appengine/tq"
"go.chromium.org/luci/common/clock"
"go.chromium.org/luci/common/clock/testclock"
)
// GetTestable returns a Testable wrapping the given dispatcher and the Task
// Queue implementation installed in the context. It is meant for tests only.
//
// Panics if the context holds a production (non-testable) Task Queue
// implementation, or if the dispatcher's internals don't expose the methods
// this package needs.
func GetTestable(ctx context.Context, d *tq.Dispatcher) Testable {
	testable := taskqueue.GetTestable(ctx)
	if testable == nil {
		panic("not a testable task queue implementation")
	}
	return &testableImpl{
		d:   d.Internals().(dispatcherInternals),
		tqt: testable,
	}
}
// Testable can be used from unit tests that post TQ tasks.
//
// It assumes Dispatcher is in complete control of all Task Queue tasks. Tasks
// added to Task Queue directly, with URL paths outside of the dispatcher's
// base URL, are not returned by GetScheduledTasks. RunSimulation still picks
// them up and treats them as unrecognized tasks.
type Testable interface {
	// CreateQueues creates all push queues used by registered tasks.
	//
	// The "default" queue always exists and is skipped.
	CreateQueues()

	// GetScheduledTasks fetches all scheduled tasks submitted through the
	// dispatcher, i.e. tasks whose URL path starts with its base URL.
	//
	// Tasks with unrecognized payloads are included with nil Payload.
	// Returned tasks are sorted by ETA (earliest first) and Name.
	GetScheduledTasks() TaskList

	// ExecuteTask executes a handler for the given task in a derivative of a
	// given context.
	//
	// hdr, if not nil, is used as a base for the request headers passed to the
	// handler. QueueName, TaskName and TaskETA are always filled in from the
	// task itself.
	//
	// Returns whatever the handler returns, or a general error if the task
	// can't be dispatched (e.g. it has no Payload).
	ExecuteTask(ctx context.Context, task Task, hdr *taskqueue.RequestHeaders) error

	// RunSimulation simulates task queue service by running enqueued tasks.
	//
	// It looks at all pending tasks, picks the one with smallest ETA, moves the
	// test clock and executes the task, looks at all pending tasks again, picks
	// the one with smallest ETA, and so on ...
	//
	// Panics if there's no test clock in the context. Assumes complete control
	// of the task queue service (e.g if something is popping or resetting tasks
	// in parallel, bad things will happen).
	//
	// If it encounters an unrecognized task, calls params.UnknownTaskHandler to
	// handle it. If there's no such handler, the unrecognized task fails the
	// simulation with an error. Unrecognized tasks are still returned in
	// 'executed' (and in 'pending', if they are routed to the dispatcher's base
	// URL), except they don't have 'Payload' set.
	//
	// It stops whenever any of the following happens:
	//   * The queue of pending tasks is empty.
	//   * ETA of the next task is past deadline (set via SimulationParams).
	//   * ShouldStopBefore(...) returns true for the next to-be-executed task.
	//   * ShouldStopAfter(...) returns true for the just-executed task.
	//   * A task returns an error. The bad task will be last in 'executed' list.
	//
	// Returns:
	//   executed: executed tasks, in order of their execution.
	//   pending: tasks still scheduled when the simulation stopped, as returned
	//     by GetScheduledTasks.
	//   err: an error produced by the failed task (when exiting on an error).
	RunSimulation(ctx context.Context, params *SimulationParams) (executed, pending TaskList, err error)
}
// SimulationParams are passed to RunSimulation.
//
// All fields are optional. Passing a nil *SimulationParams is equivalent to
// passing a zero value.
type SimulationParams struct {
	// Deadline stops the simulation before executing a task whose ETA is after
	// it. The zero value means "don't stop on deadline".
	Deadline time.Time

	// ShouldStopBefore, if set, is called with the next task before it is
	// executed. Returning true stops the simulation without removing that task
	// from its queue.
	ShouldStopBefore func(t Task) bool

	// ShouldStopAfter, if set, is called with a task after it was executed
	// without an error. Returning true stops the simulation.
	ShouldStopAfter func(t Task) bool

	// UnknownTaskHandler, if set, handles tasks whose payload isn't recognized
	// by the dispatcher (Task.Payload is nil). If nil, such tasks fail the
	// simulation with an error.
	UnknownTaskHandler func(t *taskqueue.Task) error
}
// Task represents a scheduled tq Task.
type Task struct {
	// Task is the original task queue task.
	Task *taskqueue.Task
	// Payload is the deserialized payload, or nil if the dispatcher couldn't
	// recognize the task's payload.
	Payload proto.Message
}
// dispatcherInternals is secretly exposed by Dispatcher.Internals.
//
// Implemented in tq.internalsImpl.
//
// BEWARE: There are no compile time type checks here. If you add a method or
// modify existing one make sure tq.internalsImpl is modified accordingly. Unit
// tests will fail if something is not right.
type dispatcherInternals interface {
	// GetBaseURL returns the URL path prefix of tasks routed to the dispatcher.
	// Tasks with other paths are not treated as tq tasks by GetScheduledTasks.
	GetBaseURL() string
	// GetAllQueues returns names of all queues used by registered tasks.
	GetAllQueues() []string
	// GetPayload deserializes a raw task payload. Callers here treat an error
	// as "unrecognized task". Presumably it means the payload type isn't
	// registered; confirm against tq.internalsImpl.
	GetPayload(blob []byte) (proto.Message, error)
	// GetHandler returns the handler for the given payload and the queue name
	// that ends up in the QueueName request header.
	GetHandler(payload proto.Message) (cb tq.Handler, q string, err error)
	// WithRequestHeaders returns a derivative of ctx carrying the given request
	// headers, to be passed to a handler.
	WithRequestHeaders(ctx context.Context, hdr *taskqueue.RequestHeaders) context.Context
}
////////////////////////////////////////////////////////////////////////////////
// TaskList is a list of tasks implementing sort.Interface. Sorting orders
// tasks by ETA (earliest first), breaking ties by Name.
type TaskList []Task

// Len returns the number of tasks in the list.
func (l TaskList) Len() int {
	return len(l)
}

// Swap exchanges the tasks at positions i and j.
func (l TaskList) Swap(i, j int) {
	tmp := l[i]
	l[i] = l[j]
	l[j] = tmp
}

// Less reports whether the task at position i sorts before the one at j.
func (l TaskList) Less(i, j int) bool {
	a, b := &l[i], &l[j]
	return a.isLessThan(b)
}
// isLessThan reports whether t sorts before other. Tasks are ordered by ETA,
// and tasks with the same ETA are ordered by Name.
func (t *Task) isLessThan(other *Task) bool {
	mine, theirs := t.Task.ETA, other.Task.ETA
	if !mine.Equal(theirs) {
		return mine.Before(theirs)
	}
	return t.Task.Name < other.Task.Name
}
// Payloads returns the Payload of every task in the list, in list order.
// Unrecognized tasks contribute nil entries. The result is never nil, even
// for an empty list.
//
// Useful when writing asserts in tests.
func (l TaskList) Payloads() (out []proto.Message) {
	out = make([]proto.Message, 0, len(l))
	for _, task := range l {
		out = append(out, task.Payload)
	}
	return out
}
////////////////////////////////////////////////////////////////////////////////
// testableImpl actually implements Testable.
type testableImpl struct {
	d   dispatcherInternals // internals of the tq.Dispatcher under test
	tqt taskqueue.Testable  // testable Task Queue implementation taken from the context
}
// CreateQueues implements Testable by creating every queue the dispatcher
// uses in the testable Task Queue implementation.
func (t *testableImpl) CreateQueues() {
	for _, name := range t.d.GetAllQueues() {
		if name == "default" {
			continue // "default" queue always exists
		}
		t.tqt.CreateQueue(name)
	}
}
// GetScheduledTasks implements Testable.
//
// Only tasks whose URL path starts with the dispatcher's base URL are
// returned. A task whose payload can't be deserialized is still included,
// with a nil Payload. The result is ordered by ETA, then by Name.
func (t *testableImpl) GetScheduledTasks() (out TaskList) {
	prefix := t.d.GetBaseURL()
	for _, perQueue := range t.tqt.GetScheduledTasks() {
		for _, task := range perQueue {
			if !strings.HasPrefix(task.Path, prefix) {
				continue // not submitted by tq.Dispatcher
			}
			// The error is deliberately ignored: unrecognized tasks are
			// reported with a nil Payload.
			payload, _ := t.d.GetPayload(task.Payload)
			out = append(out, Task{Payload: payload, Task: task})
		}
	}
	sort.Sort(out)
	return out
}
// ExecuteTask implements Testable.
//
// It asks the dispatcher for the handler and queue associated with the task's
// payload, then calls the handler with a context carrying request headers.
// The fields of hdr, if given, serve as a base. QueueName, TaskName and
// TaskETA are always overwritten with values taken from the task.
func (t *testableImpl) ExecuteTask(ctx context.Context, task Task, hdr *taskqueue.RequestHeaders) error {
	if task.Payload == nil {
		return fmt.Errorf("can't execute a task without payload, not a tq task?")
	}
	handler, queueName, err := t.d.GetHandler(task.Payload)
	if err != nil {
		return err
	}

	var headers taskqueue.RequestHeaders
	if hdr != nil {
		headers = *hdr // work on a copy so the caller's struct isn't modified
	}
	headers.QueueName, headers.TaskName, headers.TaskETA = queueName, task.Task.Name, task.Task.ETA

	return handler(t.d.WithRequestHeaders(ctx, &headers), task.Payload)
}
////////////////////////////////////////////////////////////////////////////////
// RunSimulation implements Testable.
//
// It repeatedly picks the pending task with the smallest ETA, removes it from
// its queue, advances the test clock to its ETA and executes it, until one of
// the stop conditions described in SimulationParams is met or a task fails.
//
// The test clock is only ever moved forward. If the next task's ETA is
// already in the past (e.g. the test or a handler advanced the clock
// manually), the task runs at the current time and the clock isn't rewound.
//
// Panics if the context doesn't carry a testclock.TestClock.
func (t *testableImpl) RunSimulation(ctx context.Context, params *SimulationParams) (executed, pending TaskList, err error) {
	tc := clock.Get(ctx).(testclock.TestClock)

	// A nil params is equivalent to all-default parameters.
	var deadline time.Time
	var shouldStopBefore, shouldStopAfter func(t Task) bool
	var unknownHandler func(t *taskqueue.Task) error
	if params != nil {
		deadline = params.Deadline
		shouldStopBefore = params.ShouldStopBefore
		shouldStopAfter = params.ShouldStopAfter
		unknownHandler = params.UnknownTaskHandler
	}

loop:
	for {
		earliest, queue := t.pickEarliestETA()
		switch {
		case earliest == nil:
			break loop // no more tasks
		case !deadline.IsZero() && earliest.Task.ETA.After(deadline):
			break loop // deadline reached
		case shouldStopBefore != nil && shouldStopBefore(*earliest):
			break loop // stop condition reached, the task stays in its queue
		}

		// Dequeue the task before running it so it is not picked again and
		// doesn't show up in 'pending'.
		if delErr := taskqueue.Delete(ctx, queue, earliest.Task); delErr != nil {
			panic(fmt.Sprintf("impossible, the task must be in the queue: %s", delErr))
		}
		executed = append(executed, *earliest)

		// Advance the simulated time to the task's ETA, never rewinding it.
		if earliest.Task.ETA.After(tc.Now()) {
			tc.Set(earliest.Task.ETA)
		}

		switch {
		case earliest.Payload != nil:
			err = t.ExecuteTask(ctx, *earliest, nil)
		case unknownHandler != nil:
			err = unknownHandler(earliest.Task)
		default:
			err = fmt.Errorf("unrecognized TQ task for handler at %s", earliest.Task.Path)
		}
		if err != nil {
			break loop // the failed task is the last one in 'executed'
		}
		if shouldStopAfter != nil && shouldStopAfter(*earliest) {
			break loop
		}
	}

	pending = t.GetScheduledTasks()
	return executed, pending, err
}
// pickEarliestETA returns the pending task with the smallest (ETA, Name) pair
// across all queues, along with the name of the queue holding it. It returns
// a nil task if nothing is pending.
//
// Unlike GetScheduledTasks, it doesn't filter tasks by the dispatcher's base
// URL, so tasks added to Task Queue directly are considered too. The returned
// task's Payload is nil if the dispatcher can't deserialize it.
func (t *testableImpl) pickEarliestETA() (earliest *Task, queue string) {
	// TODO(vadimsh): This is horribly inefficient in case there are large number
	// of pending tasks. If it becomes an issue, taskqueue's Testable interface
	// should be modified to return earliest task directly (it can pick it more
	// efficiently, since it stores tasks in sorted priority queues already).
	for q, tasks := range t.tqt.GetScheduledTasks() {
		for _, task := range tasks {
			// Ordering depends only on ETA and Name, so payloads of tasks that
			// don't win aren't deserialized at all.
			candidate := Task{Task: task}
			if earliest == nil || candidate.isLessThan(earliest) {
				earliest = &candidate
				queue = q
			}
		}
	}
	if earliest != nil {
		// The error is deliberately ignored: an unrecognized task is reported
		// with a nil Payload and handled by the caller.
		earliest.Payload, _ = t.d.GetPayload(earliest.Task.Payload)
	}
	return earliest, queue
}