Skip to content

Instantly share code, notes, and snippets.

@ronekko
Created December 1, 2017 10:24
Show Gist options
  • Select an option

  • Save ronekko/bbc9191e331642ff6ff846fb00b04a43 to your computer and use it in GitHub Desktop.

Select an option

Save ronekko/bbc9191e331642ff6ff846fb00b04a43 to your computer and use it in GitHub Desktop.
# -*- coding: utf-8 -*-
"""
Created on Tue May 21 12:50:16 2013
@author: ryuhei
"""
import numpy as np
import matplotlib.pyplot as plt
M = 100  # number of Dirichlet samples to draw
# Base concentration parameters; they are rescaled by 0.1 just below, so the
# distribution actually sampled is Dirichlet([0.5, 0.3, 0.2]).
alpha = np.array([5.0, 3.0, 2.0])
K = len(alpha)  # number of components in each sample
alpha *= 0.1
# theta has shape (M, K): each row is one sample whose entries sum to 1.
theta = np.random.dirichlet(alpha, M)
# Empirical mean over the samples (averaging over axis 0, i.e. over rows).
# For large M this is expected to approach alpha / alpha.sum(), the mean of
# the Dirichlet distribution.
theta_avg = theta.mean(axis=0)
print("alpha: " + str(alpha))
print("theta_avg: " + str(theta_avg))
print("theta[0:10]:\n" + str(theta[0:10]))
# Figure 1: one pie chart per sample, filled row by row. At most the first
# 100 samples are drawn so the figure stays legible; the grid is as close to
# square as possible (exactly 10x10 when M == 100). The grid is derived from M
# rather than hard-coded, so a smaller M no longer indexes past the end of theta.
n_panels = min(M, 100)
n_cols = max(1, int(np.ceil(np.sqrt(n_panels))))  # max() avoids 0 columns when M == 0
n_rows = int(np.ceil(n_panels / n_cols))
plt.figure()
for idx in range(n_panels):
    plt.subplot(n_rows, n_cols, idx + 1)
    plt.pie(theta[idx])
    plt.axis('equal')  # equal aspect keeps each pie circular

# Figure 2: pie chart of the empirical mean, titled with the parameters used.
plt.figure()
plt.axis('equal')
plt.pie(theta_avg)
plt.title("alpha: " + str(alpha) + "\ntheta_avg: " + str(theta_avg))
plt.show()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment