add Each method
diff --git a/bench_test.go b/bench_test.go
index e98bf91..f893d10 100644
--- a/bench_test.go
+++ b/bench_test.go
@@ -501,6 +501,44 @@
benchUnion(b, 100, NewThreadUnsafeSet(), NewThreadUnsafeSet())
}
+// benchEach fills s with the values produced by nrand(n) and then times
+// b.N complete Each passes over s using a callback that never asks to stop.
+func benchEach(b *testing.B, n int, s Set) {
+	for _, num := range nrand(n) {
+		s.Add(num)
+	}
+
+	// visit always returns false, so each Each call walks the whole set.
+	visit := func(interface{}) bool { return false }
+
+	b.ResetTimer() // keep the population above out of the measurement
+	for iter := 0; iter < b.N; iter++ {
+		s.Each(visit)
+	}
+}
+
+// BenchmarkEach1Safe runs benchEach with n = 1 against a set built by NewSet.
+func BenchmarkEach1Safe(b *testing.B) {
+	target := NewSet()
+	benchEach(b, 1, target)
+}
+
+// BenchmarkEach1Unsafe runs benchEach with n = 1 against a set built by
+// NewThreadUnsafeSet.
+func BenchmarkEach1Unsafe(b *testing.B) {
+	target := NewThreadUnsafeSet()
+	benchEach(b, 1, target)
+}
+
+// BenchmarkEach10Safe runs benchEach with n = 10 against a set built by NewSet.
+func BenchmarkEach10Safe(b *testing.B) {
+	target := NewSet()
+	benchEach(b, 10, target)
+}
+
+// BenchmarkEach10Unsafe runs benchEach with n = 10 against a set built by
+// NewThreadUnsafeSet.
+func BenchmarkEach10Unsafe(b *testing.B) {
+	target := NewThreadUnsafeSet()
+	benchEach(b, 10, target)
+}
+
+// BenchmarkEach100Safe runs benchEach with n = 100 against a set built by
+// NewSet.
+func BenchmarkEach100Safe(b *testing.B) {
+	target := NewSet()
+	benchEach(b, 100, target)
+}
+
+// BenchmarkEach100Unsafe runs benchEach with n = 100 against a set built by
+// NewThreadUnsafeSet.
+func BenchmarkEach100Unsafe(b *testing.B) {
+	target := NewThreadUnsafeSet()
+	benchEach(b, 100, target)
+}
+
func benchIter(b *testing.B, n int, s Set) {
nums := nrand(n)
for _, v := range nums {
diff --git a/set.go b/set.go
index 609e093..7411982 100644
--- a/set.go
+++ b/set.go
@@ -126,6 +126,10 @@
// panic.
IsSuperset(other Set) bool
+ // Iterates over the elements of the set, calling the passed func
+ // once for each element. If the passed func returns true, iteration
+ // stops immediately and the remaining elements are not visited.
+ Each(func(interface{}) bool)
+
// Returns a channel of elements that you can
// range over.
Iter() <-chan interface{}
diff --git a/set_test.go b/set_test.go
index e560edb..4776e2d 100644
--- a/set_test.go
+++ b/set_test.go
@@ -849,6 +849,37 @@
}
}
+// Test_Each checks that Each hands every element of a set to the callback
+// and that it stops iterating as soon as the callback returns true.
+func Test_Each(t *testing.T) {
+	a := NewSet()
+
+	a.Add("Z")
+	a.Add("Y")
+	a.Add("X")
+	a.Add("W")
+
+	// Copy a into b through Each; a callback that always returns false
+	// must be given every element.
+	b := NewSet()
+	a.Each(func(elem interface{}) bool {
+		b.Add(elem)
+		return false
+	})
+
+	if !a.Equal(b) {
+		t.Error("The sets are not equal after iterating (Each) through the first set")
+	}
+
+	// Count every invocation and request a stop on the second one. If Each
+	// ignored the return value, the callback would run for all four elements
+	// and count would end at 4, so the check below really exercises the
+	// early exit.
+	var count int
+	a.Each(func(elem interface{}) bool {
+		count++
+		return count == 2
+	})
+	if count != 2 {
+		t.Error("Iteration should stop on the way")
+	}
+}
+
func Test_Iter(t *testing.T) {
a := NewSet()
diff --git a/threadsafe.go b/threadsafe.go
index d7dd2d2..8dae161 100644
--- a/threadsafe.go
+++ b/threadsafe.go
@@ -151,6 +151,16 @@
return len(set.s)
}
+// Each calls cb for the elements of the set while holding the set's read
+// lock, and stops as soon as cb returns true.
+//
+// The lock is released with defer so that a panic inside cb cannot leave it
+// held. Because cb runs under the read lock, it should not call methods on
+// the same set that need the write lock (assuming the usual sync.RWMutex
+// semantics, that would deadlock).
+func (set *threadSafeSet) Each(cb func(interface{}) bool) {
+	set.RLock()
+	defer set.RUnlock()
+
+	for elem := range set.s {
+		if cb(elem) {
+			return
+		}
+	}
+}
+
func (set *threadSafeSet) Iter() <-chan interface{} {
ch := make(chan interface{})
go func() {
diff --git a/threadsafe_test.go b/threadsafe_test.go
index 858e59d..5c32fcb 100644
--- a/threadsafe_test.go
+++ b/threadsafe_test.go
@@ -30,6 +30,7 @@
"math/rand"
"runtime"
"sync"
+ "sync/atomic"
"testing"
)
@@ -294,6 +295,35 @@
wg.Wait()
}
+// Test_EachConcurrent runs Each on one shared set from several goroutines at
+// once and checks that the callback was invoked N times per goroutine.
+func Test_EachConcurrent(t *testing.T) {
+	runtime.GOMAXPROCS(2)
+	workers := 10
+
+	shared := NewSet()
+	for _, v := range rand.Perm(N) {
+		shared.Add(v)
+	}
+
+	var (
+		count int64
+		wg    sync.WaitGroup
+	)
+	wg.Add(workers)
+	for i := 0; i < workers; i++ {
+		go func() {
+			defer wg.Done()
+			shared.Each(func(interface{}) bool {
+				atomic.AddInt64(&count, 1)
+				return false
+			})
+		}()
+	}
+	wg.Wait()
+
+	// wg.Wait orders this plain read after every atomic increment.
+	if want := int64(N * workers); count != want {
+		t.Errorf("%v != %v", count, want)
+	}
+}
+
func Test_IterConcurrent(t *testing.T) {
runtime.GOMAXPROCS(2)
diff --git a/threadunsafe.go b/threadunsafe.go
index 61a056f..fec2e37 100644
--- a/threadunsafe.go
+++ b/threadunsafe.go
@@ -159,6 +159,14 @@
return len(*set)
}
+// Each invokes cb on the elements of the set one at a time and ends the
+// iteration early the first time cb returns true.
+func (set *threadUnsafeSet) Each(cb func(interface{}) bool) {
+	for item := range *set {
+		if stop := cb(item); stop {
+			return
+		}
+	}
+}
+
func (set *threadUnsafeSet) Iter() <-chan interface{} {
ch := make(chan interface{})
go func() {