blob: 79368eb2d899614ccedb3d02f9bedf88dd4f3cdc [file] [log] [blame]
// Copyright 2020 The LUCI Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Package partition encapsulates partitioning and querying of a large keyspace
// whose keys can't be expressed even as uint64.
//
// All to/from string functions use hex encoding.
package partition
import (
"container/list"
"fmt"
"math/big"
"sort"
"strings"
"go.chromium.org/luci/common/errors"
)
// Partition represents a range [Low..High).
//
// Both bounds are arbitrary-precision integers, so the range may span a
// keyspace far wider than 64 bits. The zero value is the empty range [0..0).
type Partition struct {
// Low is the first value that belongs to the range.
Low big.Int // inclusive
// High is the first value past the end of the range.
High big.Int // exclusive. May be equal to max SHA2 hash value + 1.
}
// SortedPartitions are disjoint partitions sorted by ascending .Low field.
//
// Nothing in the type itself enforces these invariants; it is up to the code
// producing the slice to uphold them.
type SortedPartitions []*Partition
// FromInts returns the partition [low..high) built from int64 bounds.
//
// Panics with an error if low > high.
func FromInts(low, high int64) *Partition {
	if low > high {
		// .Err() turns the annotator returned by errors.Reason into an actual
		// error, so the panic value is an `error` like every other failure
		// reported by this package.
		panic(errors.Reason("Partition %d..%d is invalid", low, high).Err())
	}
	p := &Partition{}
	p.Low.SetInt64(low)
	p.High.SetInt64(high)
	return p
}
// SpanInclusive parses hex-encoded bounds and returns the partition which
// covers [low..highInclusive], i.e. the upper bound itself is included.
//
// Returns an error if either bound fails to parse as non-negative hex or if
// the resulting range is invalid (its low end lies beyond its high end).
func SpanInclusive(low, highInclusive string) (*Partition, error) {
	span := new(Partition)
	err := setBigIntFromString(&span.Low, low)
	if err == nil {
		err = setBigIntFromString(&span.High, highInclusive)
	}
	if err != nil {
		return nil, err
	}
	// Turn the inclusive upper bound into the exclusive one stored in High.
	span.High.Add(&span.High, bigInt1)
	if span.High.Cmp(&span.Low) < 0 {
		return nil, errors.Reason("Partition %s is invalid", span.String()).Err()
	}
	return span, nil
}
// Universe returns the partition [0..2^(keySpaceBytes*8)), i.e. the range of
// all values representable in keySpaceBytes bytes.
func Universe(keySpaceBytes int) *Partition {
	all := new(Partition)
	// Low keeps its zero value. Setting the single bit at index
	// keySpaceBytes*8 makes High equal to 2^(keySpaceBytes*8).
	topBit := keySpaceBytes * 8
	all.High.SetBit(&all.High, topBit, 1)
	return all
}
// FromString parses a partition from its "<low hex>_<high hex>" text form,
// which is the form produced by Partition.String.
//
// Returns an error if the separator is missing, either side of it is empty or
// not valid non-negative hex, or the low bound exceeds the high bound.
func FromString(s string) (*Partition, error) {
	sep := strings.Index(s, "_")
	// The separator must be present with a non-empty string on both sides.
	if sep < 1 || sep+1 == len(s) {
		return nil, errors.Reason("partition %q has invalid format", s).Err()
	}
	lowHex, highHex := s[:sep], s[sep+1:]

	parsed := new(Partition)
	if err := setBigIntFromString(&parsed.Low, lowHex); err != nil {
		return nil, err
	}
	if err := setBigIntFromString(&parsed.High, highHex); err != nil {
		return nil, err
	}
	if parsed.High.Cmp(&parsed.Low) < 0 {
		return nil, errors.Reason("Partition %s is invalid", parsed.String()).Err()
	}
	return parsed, nil
}
// String returns the partition as "<low>_<high>" with both bounds hex-encoded.
func (p Partition) String() string {
	const hex = 16
	lowHex := p.Low.Text(hex)
	highHex := p.High.Text(hex)
	return fmt.Sprintf("%s_%s", lowHex, highHex)
}
// MarshalJSON encodes the partition as the JSON string "<low>_<high>" with
// both bounds hex-encoded. It never returns an error.
func (p Partition) MarshalJSON() ([]byte, error) {
	const hex = 16
	lowHex := p.Low.Text(hex)
	highHex := p.High.Text(hex)
	encoded := fmt.Sprintf(`"%s_%s"`, lowHex, highHex)
	return []byte(encoded), nil
}
// UnmarshalJSON decodes a partition from a JSON string holding the
// "<low>_<high>" hex form.
//
// JSON `null` is accepted and leaves the partition untouched. Anything which
// is not wrapped in double quotes, or whose contents FromString rejects,
// results in an error.
func (p *Partition) UnmarshalJSON(bs []byte) error {
	raw := string(bs)
	if raw == `null` {
		return nil
	}
	quoted := len(raw) >= 2 && raw[0] == '"' && raw[len(raw)-1] == '"'
	if !quoted {
		return errors.Reason("invalid JSON-serialized partition %q", raw).Err()
	}
	// Strip the surrounding quotes and parse what is inside.
	parsed, err := FromString(raw[1 : len(raw)-1])
	if err != nil {
		return err
	}
	*p = *parsed
	return nil
}
// Copy returns a newly allocated partition with the same bounds.
//
// The bounds are copied via big.Int.Set, so the result shares no storage with
// the receiver.
func (p Partition) Copy() *Partition {
	dup := new(Partition)
	dup.High.Set(&p.High)
	dup.Low.Set(&p.Low)
	return dup
}
// QueryBounds returns the partition bounds as strings for a lexicographical
// range query low <= key < high over hex-encoded keys.
//
// A bound which fits into keySpaceBytes bytes is returned as lowercase hex
// left-padded with zeros to keySpaceBytes*2 characters. A bound which doesn't
// fit is returned as "g", which is greater than any hex string.
//
// In practice, a bound outside of the key space should mean it is equal to
// 2^(keySpaceBytes*8). This is handled for Low as well as for High, so an
// empty partition sitting at the very end of the key space yields the empty
// range ["g".."g") instead of a panic while padding.
func (p Partition) QueryBounds(keySpaceBytes int) (low, high string) {
	// All hex strings are smaller than "g".
	const pastTheEnd = "g"
	bound := func(v *big.Int) string {
		if !inKeySpace(v, keySpaceBytes) {
			return pastTheEnd
		}
		return paddedHex(v, keySpaceBytes)
	}
	return bound(&p.Low), bound(&p.High)
}
// Split divides the partition into at most `shards` consecutive partitions
// which together cover exactly [p.Low..p.High).
//
// Each partition but the last spans ceil((High-Low)/shards) values; the last
// one takes whatever is left. Hence fewer than `shards` partitions may be
// returned for a small range, and none at all for an empty one.
//
// Panics if shards <= 0.
func (p Partition) Split(shards int) SortedPartitions {
	if shards <= 0 {
		panic(">=1 shard required")
	}

	// increment = ceil((High - Low) / shards).
	var increment, remainder, cur big.Int
	increment.QuoRem(
		cur.Sub(&p.High, &p.Low),
		big.NewInt(int64(shards)),
		&remainder)
	if remainder.Cmp(bigInt0) > 0 {
		increment.Add(&increment, bigInt1)
	}

	partitions := make([]*Partition, 0, shards)
	cur.Set(&p.Low)
	for cur.Cmp(&p.High) < 0 {
		next := &Partition{}
		next.Low.Set(&cur)
		next.High.Add(&cur, &increment)
		cur.Set(&next.High)
		partitions = append(partitions, next)
	}
	if len(partitions) == 0 {
		// Nothing to split: the partition is empty.
		return partitions
	}

	// Due to rounding up of the increment, the last partition may overshoot,
	// so ensure it ends exactly at the end of the original.
	//
	// Use Set rather than struct assignment: a shallow copy of big.Int would
	// share its storage with p.High, and a later in-place update of the
	// returned partition could then corrupt the original.
	last := partitions[len(partitions)-1]
	last.High.Set(&p.High)
	return partitions
}
// EducatedSplitAfter splits partition after a given boundary assuming constant
// density s.t. each shard has approximately targetItems.
//
// `exclusive` is a hex-encoded key within [p.Low..p.High); only keys strictly
// after it are covered by the result. beforeItems is the number of items
// observed in [p.Low..exclusive], which is what the density is derived from.
//
// Caps the number of resulting partitions to at most maxShards. Returns no
// partitions if `exclusive` is the very last key of the partition, and a
// single partition if the observed density is zero.
//
// Panics if called on invalid data.
func (p Partition) EducatedSplitAfter(exclusive string, beforeItems, targetItems, maxShards int) SortedPartitions {
	remaining := Partition{}
	if err := setBigIntFromString(&remaining.Low, exclusive); err != nil {
		panic(err)
	}
	if p.Low.Cmp(&remaining.Low) > 0 { // low > remaining.Low
		panic("must be within the partition")
	}
	if p.High.Cmp(&remaining.Low) <= 0 { // high <= remaining.Low
		panic("must be within the partition")
	}
	remaining.Low.Add(&remaining.Low, bigInt1) // remaining.Low++
	remaining.High.Set(&p.High)
	if remaining.Low.Cmp(&remaining.High) >= 0 {
		// The boundary is the last key of the partition: nothing remains.
		return nil
	}

	// Compute expShards as
	//
	//   beforeItems / len(before) * len(remaining) / targetItems
	//
	// in a somewhat readable way as
	//
	//   (beforeItems * len(remaining)) / (targetItems * len(before))
	//
	// rounded up.
	//
	// NOTE: this can be optimized if needed to avoid excessive memory allocations
	// in big.Int at the cost of readability.
	iBefore := big.NewInt(int64(beforeItems))
	iTarget := big.NewInt(int64(targetItems))
	var expShards, iRemainder big.Int
	expShards.QuoRem(
		(&big.Int{}).Mul(iBefore, distance(&remaining.Low, &remaining.High)),
		(&big.Int{}).Mul(iTarget, distance(&p.Low, &remaining.Low)),
		&iRemainder,
	)
	if iRemainder.Cmp(bigInt0) > 0 {
		expShards.Add(&expShards, bigInt1)
	}

	shards := maxShards
	if expShards.Cmp(big.NewInt(int64(maxShards))) < 0 {
		// expShards < maxShards, so it fits into an int.
		shards = int(expShards.Int64())
		if expShards.Sign() == 0 {
			// Zero density, but the non-empty remainder still has to be
			// covered. maxShards >= 1 here, so the cap is respected.
			shards = 1
		}
	}
	return remaining.Split(shards)
}
// SortedPartitionsBuilder constructs a sequence of partitions by excluding
// chunks from a starting partition.
//
// Not intended to scale to large number of exclusion operations.
type SortedPartitionsBuilder struct {
// l holds *Partition values in sorted order, leading to O(len(l)) runtime
// of the Exclude().
//
// For max performance with >~20 exclusions, an interval tree should be used
// instead. Unfortunately, due to lack of generics in Go, most interval tree
// libraries expect float64 or int64 bounds, not big.Int.
l *list.List
}
// NewSortedPartitionsBuilder returns a builder whose only available partition
// is a copy of p, so later exclusions never modify p itself.
func NewSortedPartitionsBuilder(p *Partition) SortedPartitionsBuilder {
	available := list.New()
	available.PushBack(p.Copy())
	return SortedPartitionsBuilder{l: available}
}
// IsEmpty reports whether no partitions are left in the builder.
func (b *SortedPartitionsBuilder) IsEmpty() bool {
	// A list has no front element exactly when it has no elements at all.
	return b.l.Front() == nil
}
// Result returns the partitions currently held by the builder, in order.
//
// The returned slice holds the very same *Partition values as the builder,
// not copies of them.
func (b *SortedPartitionsBuilder) Result() SortedPartitions {
	out := make([]*Partition, b.l.Len())
	i := 0
	for el := b.l.Front(); el != nil; el = el.Next() {
		out[i] = el.Value.(*Partition)
		i++
	}
	return out
}
// Exclude removes the range covered by `exclude` from the partitions held by
// the builder.
//
// Partitions fully covered are dropped, partially covered ones are trimmed in
// place, and a partition which strictly contains `exclude` is split in two.
func (b *SortedPartitionsBuilder) Exclude(exclude *Partition) {
	for el := b.l.Front(); el != nil; {
		avail := el.Value.(*Partition)

		if exclude.Low.Cmp(&avail.High) >= 0 {
			// avail lies entirely before exclude; look at the next one.
			el = el.Next()
			continue
		}
		if exclude.High.Cmp(&avail.Low) <= 0 {
			// exclude lies entirely before avail; since partitions are sorted,
			// none of the following ones can overlap either.
			return
		}

		// From here on, avail and exclude overlap.
		frontExcluded := exclude.Low.Cmp(&avail.Low) <= 0
		backExcluded := exclude.High.Cmp(&avail.High) >= 0
		switch {
		case frontExcluded && backExcluded:
			// Nothing of avail remains. Grab the successor before removal,
			// then keep going as exclude may reach further.
			following := el.Next()
			b.l.Remove(el)
			el = following
		case frontExcluded:
			// Only the back remains.
			avail.Low.Set(&exclude.High)
			return
		case backExcluded:
			// Only the front remains; exclude may reach into the next one.
			avail.High.Set(&exclude.Low)
			el = el.Next()
		default:
			// The middle is excluded: avail becomes the front piece and a new
			// partition for the back piece goes right after it.
			back := new(Partition)
			back.Low.Set(&exclude.High)
			back.High.Set(&avail.High)
			avail.High.Set(&exclude.Low)
			b.l.InsertAfter(back, el)
			return
		}
	}
}
// OnlyIn efficiently finds the subsequences of `n` objects sorted by key whose
// keys belong to one of the partitions.
//
// key(i) must return the key of the i-th object. use(l, h) is called for each
// non-empty run objects[l:h] whose keys fall within a partition, in ascending
// order of l.
func (ps SortedPartitions) OnlyIn(n int, key func(i int) string, use func(l, h int), keySpaceBytes int) {
	// Objects in [start..n) are the ones not yet examined.
	start := 0
	for _, part := range ps {
		if start >= n {
			return
		}
		lowStr, highStr := part.QueryBounds(keySpaceBytes)

		// First remaining object with key at or above the lower bound.
		begin := start + sort.Search(n-start, func(i int) bool {
			return key(start+i) >= lowStr
		})
		if begin == n {
			// Everything left sorts before this partition, and thus before
			// all of the following ones, too.
			return
		}
		// First object from there on with key at or above the upper bound.
		end := begin + sort.Search(n-begin, func(i int) bool {
			return key(begin+i) >= highStr
		})
		if end > begin {
			use(begin, end)
		}
		// Can be optimized more by doing binary search over `ps` if nothing
		// was skipped or used, i.e. end == start.
		start = end
	}
}
// helpers

// These are effectively constants predefined to avoid needless memory
// allocations. They must only ever be read, never used as a destination.
var bigInt0 = new(big.Int) // 0
var bigInt1 = big.NewInt(1)
// distance returns high - low as a newly allocated big.Int, leaving both
// arguments unmodified.
func distance(low, high *big.Int) *big.Int {
	diff := new(big.Int)
	diff.Sub(high, low)
	return diff
}
// setBigIntFromString sets b to the value of the hex string s.
//
// Returns an error if s is not valid hex or denotes a negative value. The
// value of b should not be relied upon when an error is returned.
func setBigIntFromString(b *big.Int, s string) error {
	_, ok := b.SetString(s, 16 /*hex*/)
	switch {
	case !ok:
		return errors.Reason("invalid bigint hex %q", s).Err()
	case b.Sign() < 0:
		return errors.Reason("negative value %q not allowed", s).Err()
	}
	return nil
}
// paddedHex returns b as lowercase hex left-padded with zeros to
// keySpaceBytes*2 characters.
//
// Panics if the hex form of b is longer than that; callers are expected to
// check the value against the key space first.
func paddedHex(b *big.Int, keySpaceBytes int) string {
	digits := b.Text(16 /*hex*/)
	width := keySpaceBytes * 2
	zeros := strings.Repeat("0", width-len(digits))
	return zeros + digits
}
// inKeySpace reports whether v fits into keySpaceBytes bytes, i.e. whether it
// needs at most keySpaceBytes*8 bits and thus stays below the upper boundary
// of the key space.
func inKeySpace(v *big.Int, keySpaceBytes int) bool {
	maxBits := keySpaceBytes * 8
	return v.BitLen() <= maxBits
}