This is documentation for an old release of NumPy (version 1.15). Read this page in the documentation of the latest stable release (version 2.2).

numpy.random.dirichlet

numpy.random.dirichlet(alpha, size=None)

Draw samples from the Dirichlet distribution.

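Draw samples of dimension k = len(alpha) from a Dirichlet distribution. The number and arrangement of the samples are set by size. A Dirichlet-distributed random variable can be seen as a multivariate generalization of a Beta distribution: for k = 2, the first component follows Beta(alpha_1, alpha_2). The Dirichlet distribution is the conjugate prior of the multinomial distribution in Bayesian inference.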
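Conjugacy means that a Dirichlet prior updated with multinomial counts gives another Dirichlet distribution, whose concentration parameters are the prior parameters plus the observed counts. The following is a minimal sketch of that update. The counts are hypothetical. The helper names dirichlet_posterior and dirichlet_mean are illustrative and are not part of NumPy.

```python
import numpy as np

def dirichlet_posterior(alpha, counts):
    """Hypothetical helper: posterior concentration after observing counts.

    With a Dirichlet(alpha) prior on category probabilities and observed
    multinomial counts n_1..n_k, the posterior is Dirichlet(alpha + counts).
    """
    alpha = np.asarray(alpha, dtype=float)
    counts = np.asarray(counts, dtype=float)
    if alpha.shape != counts.shape:
        raise ValueError("alpha and counts must have the same length")
    return alpha + counts

def dirichlet_mean(alpha):
    """Hypothetical helper: mean of Dirichlet(alpha), i.e. alpha_i / sum(alpha)."""
    alpha = np.asarray(alpha, dtype=float)
    return alpha / alpha.sum()

prior = [1.0, 1.0, 1.0]      # flat prior over 3 categories
observed = [6, 3, 1]         # hypothetical multinomial counts
post = dirichlet_posterior(prior, observed)
print(post)                  # [7. 4. 2.]
print(dirichlet_mean(post))  # posterior mean of the category probabilities

# Draw a few plausible probability vectors from the posterior
posterior_draws = np.random.RandomState(0).dirichlet(post, size=5)
```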

Parameters:
alpha : array

Parameters of the distribution: a sequence of k positive concentration parameters, which yields samples of dimension k.

size : int or tuple of ints, optional

Output shape. If the given shape is, e.g., (m, n), then m * n samples are drawn. Each sample is a vector of length k, so the returned array has shape (m, n, k). Default is None, in which case a single sample vector of length k is returned.

Returns:
samples : ndarray

The drawn samples, of shape (size, k), where k = len(alpha). When size is None, a single sample of shape (k,) is returned.

Raises:
ValueError

If any value in alpha is less than or equal to zero.
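The following sketch checks the shapes described above. It also checks that every sample lies on the simplex. It uses a seeded legacy RandomState only for reproducibility. The concentration values are arbitrary. A negative entry is used to trigger the ValueError because it is rejected in every version.

```python
import numpy as np

rs = np.random.RandomState(12345)  # seeded legacy generator for reproducibility
alpha = [10.0, 5.0, 3.0]           # k = 3

one = rs.dirichlet(alpha)                # size=None -> a single vector of length k
flat = rs.dirichlet(alpha, size=4)       # int size -> shape (4, k)
grid = rs.dirichlet(alpha, size=(2, 5))  # tuple size -> shape (2, 5, k)
print(one.shape, flat.shape, grid.shape)  # (3,) (4, 3) (2, 5, 3)

# Every sample lies on the simplex: its k components sum to 1
print(np.allclose(grid.sum(axis=-1), 1.0))  # True

# A non-positive concentration parameter is rejected
raised = False
try:
    rs.dirichlet([1.0, -1.0])
except ValueError:
    raised = True
print(raised)  # True
```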

Notes

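The Dirichlet distribution is a distribution over vectors x that satisfy x_i > 0 and \sum_{i=1}^k x_i = 1. The probability density function p of a Dirichlet-distributed random vector X is proportional to

p(x) \propto \prod_{i=1}^{k}{x^{\alpha_i-1}_i}

where \alpha is the vector of positive concentration parameters.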
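The full density adds the normalizing constant \Gamma(\sum_i \alpha_i) / \prod_i \Gamma(\alpha_i) to the proportionality above. The function below is a minimal sketch that evaluates it in log space with math.lgamma. The name dirichlet_pdf is hypothetical and is not a NumPy function. Two spot checks follow. With alpha = (1, 1, 1), the density is flat and equal to 2 on the simplex. With k = 2, the density reduces to the Beta(2, 1) density 2x.

```python
import math
import numpy as np

def dirichlet_pdf(x, alpha):
    """Hypothetical helper: density of Dirichlet(alpha) at a point x on the simplex.

    p(x) = Gamma(sum(alpha)) / prod(Gamma(alpha_i)) * prod(x_i ** (alpha_i - 1)),
    computed in log space to avoid overflow for large alpha.
    """
    x = np.asarray(x, dtype=float)
    alpha = np.asarray(alpha, dtype=float)
    if np.any(alpha <= 0):
        raise ValueError("alpha must be positive")
    if np.any(x <= 0) or not np.isclose(x.sum(), 1.0):
        raise ValueError("x must have positive entries summing to 1")
    log_norm = math.lgamma(alpha.sum()) - sum(math.lgamma(a) for a in alpha)
    log_kernel = np.sum((alpha - 1.0) * np.log(x))
    return math.exp(log_norm + log_kernel)

print(dirichlet_pdf([0.2, 0.3, 0.5], [1, 1, 1]))  # approximately 2
print(dirichlet_pdf([0.25, 0.75], [2, 1]))        # approximately 0.5
```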

Samples are computed using the following property. For each dimension i, draw a random sample y_i from a standard gamma distribution with shape alpha_i. Then X = \frac{1}{\sum_{i=1}^k{y_i}} (y_1, \ldots, y_k) is Dirichlet distributed.
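The gamma-normalization property can be sketched directly with RandomState.standard_gamma. Draw one gamma variate per dimension and divide each row by its sum. This is an illustrative re-implementation under that assumption, not NumPy's internal code. The helper name dirichlet_via_gamma is hypothetical. With many draws, the sample mean should approach alpha_i / \sum_j alpha_j.

```python
import numpy as np

def dirichlet_via_gamma(alpha, n, rs):
    """Hypothetical helper: n Dirichlet(alpha) vectors from normalized gamma draws."""
    alpha = np.asarray(alpha, dtype=float)
    # y[j, i] ~ Gamma(shape=alpha[i], scale=1), drawn independently;
    # the (k,) shape array broadcasts against size=(n, k)
    y = rs.standard_gamma(alpha, size=(n, alpha.size))
    # Divide each row by its sum so that the k components add up to 1
    return y / y.sum(axis=1, keepdims=True)

rs = np.random.RandomState(0)
alpha = [10.0, 5.0, 3.0]
x = dirichlet_via_gamma(alpha, 100000, rs)
print(x.shape)                         # (100000, 3)
print(x.mean(axis=0))                  # close to alpha / sum(alpha)
print(np.asarray(alpha) / sum(alpha))  # the theoretical mean
```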

References

[1] David MacKay, “Information Theory, Inference and Learning Algorithms,” chapter 23, http://www.inference.phy.cam.ac.uk/mackay/
[2] Wikipedia, “Dirichlet distribution,” http://en.wikipedia.org/wiki/Dirichlet_distribution

Examples

Taking an example cited in Wikipedia [2], this distribution can be used to cut strings (each of initial length 1.0) into K pieces of different lengths. Each piece has a designated average length, but the relative sizes of the pieces are allowed to vary from string to string.

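>>> import numpy as np
>>> import matplotlib.pyplot as plt
>>> # 20 strings, each cut into 3 pieces; after transposing, s[i] holds piece i
>>> s = np.random.dirichlet((10, 5, 3), 20).transpose()
>>> # Stack the pieces of each string as one horizontal bar of total length 1
>>> plt.barh(range(20), s[0])
>>> plt.barh(range(20), s[1], left=s[0], color='g')
>>> plt.barh(range(20), s[2], left=s[0]+s[1], color='r')
>>> plt.title("Lengths of Strings")
>>> plt.show()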
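In the string-cutting example, the "designated average length" of piece i is the Dirichlet mean alpha_i / \sum_j alpha_j. The overall scale \sum_j alpha_j controls how much individual strings vary around that average, because Var(X_i) = m_i (1 - m_i) / (\sum_j alpha_j + 1), where m_i is the mean. The sketch below picks alpha from target average lengths and a concentration value. Both the helper alpha_for_lengths and the concentration values are assumptions for illustration. A concentration of 18 reproduces alpha = (10, 5, 3) from the example above.

```python
import numpy as np

def alpha_for_lengths(mean_lengths, concentration):
    """Hypothetical helper: turn desired average piece lengths into alpha.

    The mean of Dirichlet(alpha) is alpha / sum(alpha), so
    alpha = concentration * mean_lengths yields those averages; a larger
    concentration means less variation between strings.
    """
    m = np.asarray(mean_lengths, dtype=float)
    if np.any(m <= 0) or not np.isclose(m.sum(), 1.0):
        raise ValueError("mean lengths must be positive and sum to 1")
    return concentration * m

rs = np.random.RandomState(1)
target = np.array([10.0, 5.0, 3.0]) / 18.0  # average lengths of the 3 pieces

results = {}
for c in (18.0, 180.0):
    pieces = rs.dirichlet(alpha_for_lengths(target, c), size=20000)
    # Store the per-piece mean and spread across strings
    results[c] = (pieces.mean(axis=0), pieces.std(axis=0))
    print(c, results[c][0].round(3), results[c][1].round(3))
```

Both concentrations give the same average piece lengths. The larger one produces pieces that stay much closer to those averages.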