This repository has been archived by the owner on Oct 9, 2023. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 23
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Implementation of Weighted Random List (#39)
- Loading branch information
1 parent
bba4422
commit dc2f954
Showing
3 changed files
with
472 additions
and
0 deletions.
There are no files selected for viewing
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,138 @@ | ||
package random | ||
|
||
import ( | ||
"context" | ||
"fmt" | ||
"math/rand" | ||
"sort" | ||
"time" | ||
|
||
"github.com/lyft/flytestdlib/logger" | ||
) | ||
|
||
//go:generate mockery -all -case=underscore | ||
|
||
// Interface to use the Weighted Random | ||
type WeightedRandomList interface { | ||
Get() Comparable | ||
GetWithSeed(seed rand.Source) (Comparable, error) | ||
List() []Comparable | ||
Len() int | ||
} | ||
|
||
// Interface for items that can be used along with WeightedRandomList | ||
type Comparable interface { | ||
Compare(to Comparable) bool | ||
} | ||
|
||
// Structure of each entry to select from | ||
type Entry struct { | ||
Item Comparable | ||
Weight float32 | ||
} | ||
|
||
type internalEntry struct { | ||
entry Entry | ||
currentTotal float32 | ||
} | ||
|
||
// WeightedRandomList selects elements randomly from the list taking into account individual weights. | ||
// Weight has to be assigned between 0 and 1. | ||
// Support deterministic results when given a particular seed source | ||
type weightedRandomListImpl struct { | ||
entries []internalEntry | ||
totalWeight float32 | ||
} | ||
|
||
func validateEntries(entries []Entry) error { | ||
if len(entries) == 0 { | ||
return fmt.Errorf("entries is empty") | ||
} | ||
for index, entry := range entries { | ||
if entry.Item == nil { | ||
return fmt.Errorf("invalid entry: nil, index %d", index) | ||
} | ||
if entry.Weight < 0 || entry.Weight > float32(1) { | ||
return fmt.Errorf("invalid weight %f, index %d", entry.Weight, index) | ||
} | ||
} | ||
return nil | ||
} | ||
|
||
// Given a list of entries with weights, returns WeightedRandomList | ||
func NewWeightedRandom(ctx context.Context, entries []Entry) (WeightedRandomList, error) { | ||
err := validateEntries(entries) | ||
if err != nil { | ||
return nil, err | ||
} | ||
|
||
sort.Slice(entries, func(i, j int) bool { | ||
return entries[i].Item.Compare(entries[j].Item) | ||
}) | ||
var internalEntries []internalEntry | ||
numberOfEntries := len(entries) | ||
totalWeight := float32(0) | ||
for _, e := range entries { | ||
totalWeight += e.Weight | ||
} | ||
|
||
currentTotal := float32(0) | ||
for _, e := range entries { | ||
if totalWeight == 0 { | ||
// This indicates that none of the entries have weight assigned. | ||
// We will assign equal weights to everyone | ||
currentTotal += 1.0 / float32(numberOfEntries) | ||
} else if e.Weight == 0 { | ||
// Entries which have zero weight are ignored | ||
logger.Debug(ctx, "ignoring entry due to empty weight %v", e) | ||
continue | ||
} | ||
|
||
currentTotal += e.Weight | ||
internalEntries = append(internalEntries, internalEntry{ | ||
entry: e, | ||
currentTotal: currentTotal, | ||
}) | ||
} | ||
|
||
return &weightedRandomListImpl{ | ||
entries: internalEntries, | ||
totalWeight: currentTotal, | ||
}, nil | ||
} | ||
|
||
func (w *weightedRandomListImpl) get(generator *rand.Rand) Comparable { | ||
randomWeight := generator.Float32() * w.totalWeight | ||
for _, e := range w.entries { | ||
if e.currentTotal >= randomWeight && e.currentTotal > 0 { | ||
return e.entry.Item | ||
} | ||
} | ||
return w.entries[len(w.entries)-1].entry.Item | ||
} | ||
|
||
// Returns a random entry based on the weights | ||
func (w *weightedRandomListImpl) Get() Comparable { | ||
randGenerator := rand.New(rand.NewSource(time.Now().UTC().UnixNano())) | ||
return w.get(randGenerator) | ||
} | ||
|
||
// For a given seed, the same entry will be returned all the time. | ||
func (w *weightedRandomListImpl) GetWithSeed(seed rand.Source) (Comparable, error) { | ||
randGenerator := rand.New(seed) | ||
return w.get(randGenerator), nil | ||
} | ||
|
||
// Lists all the entries that are eligible for selection | ||
func (w *weightedRandomListImpl) List() []Comparable { | ||
entries := make([]Comparable, len(w.entries)) | ||
for index, indexedItem := range w.entries { | ||
entries[index] = indexedItem.entry.Item | ||
} | ||
return entries | ||
} | ||
|
||
// Gets the number of items that are being considered for selection. | ||
func (w *weightedRandomListImpl) Len() int { | ||
return len(w.entries) | ||
} |
Oops, something went wrong.