且构网

分享程序员开发的那些事...
且构网 - 分享程序员编程开发的那些事

python scipy/numpy中的多项式pmf

更新时间:2021-09-16 01:47:43

我不知道内置函数,并且二项式概率也无法推广(您需要对另一组可能的结果进行归一化,因为所有计数的总和必须为n,独立的二项式将不会处理).但是,实现自己非常简单,例如:

There's no built-in function that I know of, and the binomial probabilities do not generalize (you need to normalise over a different set of possible outcomes, since the sum of all the counts must be n which won't be taken care of by independent binomials). However, it's fairly straightforward to implement yourself, for example:

import math

class Multinomial(object):
  def __init__(self, params):
    self._params = params

  def pmf(self, counts):
    if not(len(counts)==len(self._params)):
      raise ValueError("Dimensionality of count vector is incorrect")

    prob = 1.
    for i,c in enumerate(counts):
      prob *= self._params[i]**counts[i]

    return prob * math.exp(self._log_multinomial_coeff(counts))

  def log_pmf(self,counts):
    if not(len(counts)==len(self._params)):
      raise ValueError("Dimensionality of count vector is incorrect")

    prob = 0.
    for i,c in enumerate(counts):
      prob += counts[i]*math.log(self._params[i])

    return prob + self._log_multinomial_coeff(counts)

  def _log_multinomial_coeff(self, counts):
    return self._log_factorial(sum(counts)) - sum(self._log_factorial(c)
                                                    for c in counts)

  def _log_factorial(self, num):
    if not round(num)==num and num > 0:
      raise ValueError("Can only compute the factorial of positive ints")
    return sum(math.log(n) for n in range(1,num+1))

m = Multinomial([0.1, 0.1, 0.8])
print m.pmf([4,4,2])

>>2.016e-05

我对多项式系数的实现有些天真,并且在对数空间中工作以防止溢出.还应注意,n是多余的参数,因为它由计数的总和给出(并且相同的参数集适用于任何n).此外,由于中等n或较大维数会迅速下溢,因此***在日志空间中工作(这里也提供了logPMF!)

My implementation of the multinomial coefficient is somewhat naive, and works in log space to prevent overflow. Also be aware that n is superfluous as a parameter, since it's given by the sum of the counts (and the same parameter set works for any n). Furthermore, since this will quickly underflow for moderate n or large dimensionality, you're better working in log space (logPMF provided here too!)