且构网

Weighted random sample in Python

Updated: 2022-02-16 22:45:59

From your code: ..

import random

# Repeat each index `cnt` times, then sample k entries from the expanded list
weight_sample_indexes = lambda weights, k: random.sample([val
        for val, cnt in enumerate(weights) for i in range(cnt)], k)

.. I assume that the weights are positive integers and that by "without replacement" you mean without replacement over the unraveled sequence, i.e. the sequence in which each index is repeated as many times as its weight.
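
To make that assumption concrete, here is a minimal sketch of what the unraveled sequence looks like. The helper name unravel and the weights [1, 3, 2] are illustrative choices, not part of the original code.

```python
import random

def unravel(weights):
    """Expand integer weights into a flat list: index i appears weights[i] times."""
    return [index for index, count in enumerate(weights) for _ in range(count)]

weights = [1, 3, 2]          # hypothetical weights, chosen only for illustration
flat = unravel(weights)
print(flat)                  # [0, 1, 1, 1, 2, 2]

# Sampling without replacement from the flat list means that no index
# can be drawn more often than its weight.
picked = random.sample(flat, 4)
print(picked)
```

The drawback of this direct approach is memory: the flat list has as many entries as the sum of the weights.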

Here's a solution based on random.sample and a sequence class whose __getitem__ runs in O(log n), so the unraveled sequence never has to be built in memory:

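import bisect
import random
from collections import Counter
from collections.abc import Sequence  # the collections.Sequence alias was removed in Python 3.10

def weighted_sample(population, weights, k):
    """Draw k items without replacement; weights are positive integer counts."""
    return random.sample(WeightedPopulation(population, weights), k)

class WeightedPopulation(Sequence):
    """Read-only view of population in which item i appears weights[i] times."""
    def __init__(self, population, weights):
        assert len(population) == len(weights) > 0
        self.population = population
        self.cumweights = []
        cumsum = 0  # running total of the weights
        for w in weights:
            cumsum += w
            self.cumweights.append(cumsum)
    def __len__(self):
        # Total weight, i.e. the length of the unraveled sequence
        return self.cumweights[-1]
    def __getitem__(self, i):
        if not 0 <= i < len(self):
            raise IndexError(i)
        # Binary search for the first cumulative weight strictly greater than i
        return self.population[bisect.bisect(self.cumweights, i)]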
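
The index lookup in __getitem__ can be checked in isolation. The following is a small sketch, using the same "abc" population and [1, 10, 2] weights as the example in this answer; the lookup function and the virtual variable are names introduced here only for illustration.

```python
import bisect
from itertools import accumulate

population = "abc"
weights = [1, 10, 2]
cumweights = list(accumulate(weights))   # [1, 11, 13]

def lookup(i):
    """Map a virtual index in range(sum(weights)) to an item of population."""
    # bisect.bisect is bisect_right: it returns the position of the first
    # cumulative weight strictly greater than i.
    return population[bisect.bisect(cumweights, i)]

# Walking over every virtual index reproduces the unraveled sequence
# without ever storing it.
virtual = "".join(lookup(i) for i in range(cumweights[-1]))
print(virtual)   # abbbbbbbbbbcc
```

Index 0 maps to 'a', indexes 1 through 10 map to 'b', and indexes 11 and 12 map to 'c', which matches the weights.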

Example

total = Counter()
for _ in range(1000):
    sample = weighted_sample("abc", [1, 10, 2], 5)
    total.update(sample)
# Show the last sample drawn and its per-item counts
print(sample)
print("Frequencies %s" % (dict(Counter(sample)),))

# Check that values are sane: totals relative to the rarest item
print("Total " + ', '.join("%s: %.0f" % (val, count * 1.0 / min(total.values()))
                           for val, count in total.most_common()))

Output

['b', 'b', 'b', 'c', 'c']
Frequencies {'c': 2, 'b': 3}
Total b: 10, c: 2, a: 1
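
As a side note, on Python 3.9 and later random.sample accepts a counts keyword argument that treats each item as repeated the given number of times. Under the same assumption of positive integer weights, the following sketch should give equivalent sampling behavior; the function name weighted_sample_counts is hypothetical and is not part of the answer above.

```python
import random
from collections import Counter

def weighted_sample_counts(population, weights, k):
    """Sample k items without replacement, item i being available weights[i] times.

    Assumes Python 3.9+ (the counts parameter) and positive integer weights.
    """
    return random.sample(population, k, counts=weights)

sample = weighted_sample_counts("abc", [1, 10, 2], 5)
print(sample)
print(dict(Counter(sample)))
```

This is an alternative to the WeightedPopulation class rather than a replacement for it on older Python versions.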