blob: 4f9412e05fc36eaedc78aaedeb34deee39a2a1af [file] [log] [blame]
// Copyright 2016 The LUCI Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package coordinator
import (
"context"
"fmt"
"sync"
"sync/atomic"
"testing"
"time"
"go.chromium.org/luci/common/clock/testclock"
"go.chromium.org/luci/common/errors"
"go.chromium.org/luci/common/retry/transient"
"go.chromium.org/luci/logdog/common/types"
. "github.com/smartystreets/goconvey/convey"
)
// testCoordinator is an implementation of Coordinator that can be used for
// testing.
type testCoordinator struct {
// calls is the number of calls made to the interface's methods.
calls int32
// callC, if not nil, will have a token pushed to it when a call is made.
callC chan struct{}
// errC is a channel that error status will be read from if not nil.
errC chan error
}
func (c *testCoordinator) RegisterStream(ctx context.Context, s *LogStreamState, desc []byte) (
*LogStreamState, error) {
if err := c.incCalls(); err != nil {
return nil, err
}
// Set the ProtoVersion to differentiate the output State from the input.
rs := *s
rs.ProtoVersion = "remote"
return &rs, nil
}
func (c *testCoordinator) TerminateStream(ctx context.Context, tr *TerminateRequest) error {
if err := c.incCalls(); err != nil {
return err
}
return nil
}
// incCalls is an entry point for client goroutines. It offers the opportunity
// to track call count as well as trap executing goroutines within client calls.
//
// This must not be called while the lock is held, else it could lead to
// deadlock if multiple goroutines are trapped.
func (c *testCoordinator) incCalls() error {
if c.callC != nil {
c.callC <- struct{}{}
}
atomic.AddInt32(&c.calls, 1)
if c.errC != nil {
return <-c.errC
}
return nil
}
func TestStreamStateCache(t *testing.T) {
t.Parallel()
Convey(`Using a test configuration`, t, func() {
c, tc := testclock.UseTime(context.Background(), testclock.TestTimeLocal)
tcc := testCoordinator{}
st := LogStreamState{
Project: "test-project",
Path: "test-stream-path",
ID: "hash12345",
TerminalIndex: -1,
}
tr := TerminateRequest{
Project: st.Project,
Path: st.Path,
ID: st.ID,
TerminalIndex: 1337,
Secret: st.Secret,
}
// Note: In all of these tests, we check if "proto" field (ProtoVersion)
// is "remote". We use ProtoVersion as a channel between our fake remote
// service. When our fake remote service returns a LogStreamState, it sets
// "remote" to true to differentiate it from the local pushed state.
//
// If a LogStreamState has "remote" set to true, that implies that it was
// sent by the fake testing service rather than the local test.
Convey(`A streamStateCache`, func() {
ssc := NewCache(&tcc, 4, 1*time.Second)
resultC := make(chan *LogStreamState)
req := func(st *LogStreamState) {
var res *LogStreamState
defer func() {
resultC <- res
}()
st, err := ssc.RegisterStream(c, st, nil)
if err == nil {
res = st
}
}
Convey(`Can register a stream`, func() {
s, err := ssc.RegisterStream(c, &st, nil)
So(err, ShouldBeNil)
So(s.ProtoVersion, ShouldEqual, "remote")
So(tcc.calls, ShouldEqual, 1)
Convey(`Will not re-register the same stream.`, func() {
st.ProtoVersion = ""
s, err := ssc.RegisterStream(c, &st, nil)
So(err, ShouldBeNil)
So(s.ProtoVersion, ShouldEqual, "remote")
So(tcc.calls, ShouldEqual, 1)
})
Convey(`When the registration expires`, func() {
st.ProtoVersion = ""
tc.Add(time.Second)
Convey(`Will re-register the stream.`, func() {
s, err := ssc.RegisterStream(c, &st, nil)
So(err, ShouldBeNil)
So(s.ProtoVersion, ShouldEqual, "remote")
So(tcc.calls, ShouldEqual, 2)
})
})
Convey(`Can terminate a registered stream`, func() {
So(ssc.TerminateStream(c, &tr), ShouldBeNil)
So(tcc.calls, ShouldEqual, 2) // +1 call
Convey(`Registering the stream will include the terminal index.`, func() {
// Fill it in with junk to make sure we are getting cached.
st.TerminalIndex = 123
st.ProtoVersion = ""
s, err := ssc.RegisterStream(c, &st, nil)
So(err, ShouldBeNil)
So(s.ProtoVersion, ShouldEqual, "remote")
So(s.TerminalIndex, ShouldEqual, 1337)
So(tcc.calls, ShouldEqual, 2) // No additional calls.
})
})
})
Convey(`Can register a stream with a terminal index`, func() {
st.TerminalIndex = 1337
s, err := ssc.RegisterStream(c, &st, nil)
So(err, ShouldBeNil)
So(s.ProtoVersion, ShouldEqual, "remote")
So(tcc.calls, ShouldEqual, 1)
Convey(`A subsequent call to TerminateStream will be ignored, since we have remote terminal confirmation.`, func() {
tr.TerminalIndex = 12345
So(ssc.TerminateStream(c, &tr), ShouldBeNil)
So(tcc.calls, ShouldEqual, 1) // (No additional calls)
Convey(`A register stream call will return the confirmed terminal index.`, func() {
st.TerminalIndex = 0
s, err := ssc.RegisterStream(c, &st, nil)
So(err, ShouldBeNil)
So(s.ProtoVersion, ShouldEqual, "remote")
So(tcc.calls, ShouldEqual, 1) // (No additional calls)
So(s.TerminalIndex, ShouldEqual, 1337)
})
})
Convey(`A subsqeuent register stream call will return the confirmed terminal index.`, func() {
st.TerminalIndex = 0
s, err := ssc.RegisterStream(c, &st, nil)
So(err, ShouldBeNil)
So(s.ProtoVersion, ShouldEqual, "remote")
So(tcc.calls, ShouldEqual, 1) // (No additional calls)
So(s.TerminalIndex, ShouldEqual, 1337)
})
})
Convey(`When multiple goroutines register the same stream, it gets registered once.`, func() {
tcc.callC = make(chan struct{})
tcc.errC = make(chan error)
errs := make(errors.MultiError, 256)
for i := 0; i < len(errs); i++ {
go req(&st)
}
<-tcc.callC
tcc.errC <- nil
for i := 0; i < len(errs); i++ {
<-resultC
}
So(errors.SingleError(errs), ShouldBeNil)
So(tcc.calls, ShouldEqual, 1)
})
Convey(`Multiple registrations for the same stream will result in two requests if the first expires.`, func() {
tcc.callC = make(chan struct{})
tcc.errC = make(chan error)
// First request.
go req(&st)
// Wait for the request to happen, then advance time past the request's
// expiration.
<-tcc.callC
tc.Add(time.Second)
// Second request.
go req(&st)
// Release both calls and reap the results.
<-tcc.callC
tcc.errC <- nil
tcc.errC <- nil
r1 := <-resultC
r2 := <-resultC
So(r1.ProtoVersion, ShouldEqual, "remote")
So(r2.ProtoVersion, ShouldEqual, "remote")
So(tcc.calls, ShouldEqual, 2)
})
Convey(`RegisterStream`, func() {
Convey(`A transient registration error will result in a RegisterStream error.`, func() {
tcc.errC = make(chan error, 1)
tcc.errC <- errors.New("test error", transient.Tag)
_, err := ssc.RegisterStream(c, &st, nil)
So(err, ShouldNotBeNil)
So(tcc.calls, ShouldEqual, 1)
Convey(`A second request will call through, try again, and succeed.`, func() {
tcc.errC = nil
_, err := ssc.RegisterStream(c, &st, nil)
So(err, ShouldBeNil)
So(tcc.calls, ShouldEqual, 2)
})
})
Convey(`A non-transient registration error will result in a RegisterStream error.`, func() {
tcc.errC = make(chan error, 1)
tcc.errC <- errors.New("test error")
_, err := ssc.RegisterStream(c, &st, nil)
So(err, ShouldNotBeNil)
So(tcc.calls, ShouldEqual, 1)
Convey(`A second request will return the cached error.`, func() {
tcc.errC = nil
_, err := ssc.RegisterStream(c, &st, nil)
So(err, ShouldNotBeNil)
So(tcc.calls, ShouldEqual, 1)
})
})
})
Convey(`TerminateStream`, func() {
tr := TerminateRequest{
Project: st.Project,
ID: st.ID,
TerminalIndex: 1337,
}
Convey(`The termination endpoint returns a transient error, it will propagate.`, func() {
tcc.errC = make(chan error, 1)
tcc.errC <- errors.New("test error", transient.Tag)
err := ssc.TerminateStream(c, &tr)
So(transient.Tag.In(err), ShouldBeTrue)
So(tcc.calls, ShouldEqual, 1)
Convey(`A second attempt will call through, try again, and succeed.`, func() {
tcc.errC = nil
err := ssc.TerminateStream(c, &tr)
So(err, ShouldBeNil)
So(tcc.calls, ShouldEqual, 2)
})
})
Convey(`When the termination endpoint returns a non-transient error, it will propagate.`, func() {
tcc.errC = make(chan error, 1)
tcc.errC <- errors.New("test error")
err := ssc.TerminateStream(c, &tr)
So(err, ShouldNotBeNil)
So(tcc.calls, ShouldEqual, 1)
Convey(`A second request will return the cached error.`, func() {
tcc.errC = nil
err := ssc.TerminateStream(c, &tr)
So(err, ShouldNotBeNil)
So(tcc.calls, ShouldEqual, 1)
})
})
})
Convey(`Different projects with the same stream name will not conflict.`, func() {
var projects = []string{"", "foo", "bar"}
for i, p := range projects {
st.Project = p
s, err := ssc.RegisterStream(c, &st, nil)
So(err, ShouldBeNil)
So(ssc.TerminateStream(c, &TerminateRequest{
Project: s.Project,
Path: s.Path,
ID: s.ID,
TerminalIndex: types.MessageIndex(i),
}), ShouldBeNil)
}
So(tcc.calls, ShouldEqual, len(projects)*2)
for i, p := range projects {
st.Project = p
st.TerminalIndex = -1
s, err := ssc.RegisterStream(c, &st, nil)
So(err, ShouldBeNil)
So(s.TerminalIndex, ShouldEqual, types.MessageIndex(i))
}
So(tcc.calls, ShouldEqual, len(projects)*2)
})
})
Convey(`A streamStateCache can register multiple streams at once.`, func() {
ssc := NewCache(&tcc, 0, 0)
tcc.callC = make(chan struct{})
tcc.errC = make(chan error)
count := 2048
wg := sync.WaitGroup{}
errs := make(errors.MultiError, count)
state := make([]*LogStreamState, count)
wg.Add(count)
for i := 0; i < count; i++ {
st := st
st.Path = types.StreamPath(fmt.Sprintf("ID:%d", i))
go func(i int) {
defer wg.Done()
state[i], errs[i] = ssc.RegisterStream(c, &st, nil)
}(i)
}
// Wait for all of them to simultaneously call.
for i := 0; i < count; i++ {
<-tcc.callC
}
// They're all blocked on errC; allow them to continue.
for i := 0; i < count; i++ {
tcc.errC <- nil
}
// Wait for them to finish.
wg.Wait()
// Confirm that all registered successfully.
So(errors.SingleError(errs), ShouldBeNil)
remotes := 0
for i := 0; i < count; i++ {
if state[i].ProtoVersion == "remote" {
remotes++
}
}
So(remotes, ShouldEqual, count)
})
})
}