% L6Bdirichlet Demo of Dirichlet distribution -- with animation
%
% This time we are working with a k-sided (possibly loaded) die.
% The parameter vector contains k probability parameters,
% and its distribution is a k-dimensional Dirichlet distribution.
%
% Visualizations are one-dimensional projections because we
% don't have a high-dimensional display. The upper panel shows the
% posterior density function of each parameter (separately).
% The lower panel shows 95% credible intervals for each parameter.
%
% Input is a sequence of die rolls (integers from 1 to k).
% Alternatively, if no input (or empty input) is given,
% then the user is prompted for the rolls, one at a time
% (entering 0 stops).
%
% [postpara, X] = L6Bdirichlet(X, k, plot_every)
%   X          rolls to process, integers in 1..k
%              (default []: prompt the user for rolls instead)
%   k          number of faces on the die (default 6)
%   plot_every redraw the figure after every plot_every observations
%              (default 1); the final posterior is always drawn
%   postpara   1-by-k parameter vector of the posterior Dirichlet:
%              the uniform prior Dirichlet(1,...,1) plus the roll counts
%   X          the rolls that were processed; in interactive mode these
%              are the rolls typed in by the user
%
% Uses betapdf/betainv (Statistics and Machine Learning Toolbox).
function [postpara,X]=L6Bdirichlet(X, k, plot_every)
if nargin<1, X=[]; end
if nargin<2, k=6; end
if nargin<3, plot_every=1; end
askuser = isempty(X);

if askuser
    X = [];       % rolls are appended as the user types them
    nmax = 1e6;   % effectively unbounded; the user stops by entering 0
else
    % A roll > k would silently grow postpara beyond k entries and a
    % roll < 1 or a non-integer would fail as an index, so reject them
    % up front with a clear message.
    if any(X(:)<1 | X(:)>k | X(:)~=round(X(:)))
        error('L6Bdirichlet:invalidRoll', ...
            'Rolls must be integers between 1 and %d.', k);
    end
    nmax = numel(X);   % numel matches the linear indexing X(n) below
end

% Prior parameters for Dirichlet:
% Dirichlet(1,1,...,1) is the uniform distribution over the
% possible region.
priorpara = ones(1,k);
postpara = priorpara;
seen = zeros(1,k);   % number of times each face has been observed

% Fine grid of values for a single parameter, for plotting
gridsize = 1000;
u = linspace(0,1,gridsize)';

clf
if askuser
    fprintf('Input roll results (between 1 and %d) one at a time, or 0 to terminate.\n', k);
end

lastplotted = -1;
for n=0:nmax
    if n>0
        % One more observation => increase one of the parameters.
        postpara(X(n)) = postpara(X(n))+1;
        seen(X(n)) = seen(X(n))+1;
    end

    if mod(n,plot_every)==0 || n==nmax
        plot_posterior(u, postpara, seen, n);
        lastplotted = n;
        if ~askuser && n<nmax
            pause(0.1)   % animation delay between frames
        end
    end

    if askuser && n<nmax
        x = read_roll(k);
        if x==0, break; end
        X = [X x]; %#ok<AGROW>
    end
end

% The user may stop between redraws (plot_every > 1); make sure the
% final posterior is on screen.
if lastplotted ~= n
    plot_posterior(u, postpara, seen, n);
end
end

function plot_posterior(u, postpara, seen, n)
% Draw the posterior after n observations: the marginal density of each
% parameter in the upper panel and its 95% credible interval in the
% lower panel. The marginal distribution of the i'th parameter is the
% beta distribution that lumps together the other parameters as the
% "other" choice: Beta(postpara(i), sum(postpara) - postpara(i)).
k = numel(postpara);
total = sum(postpara);

% Marginal densities of each of the k parameters
f = zeros(numel(u), k);
legendtext = cell(1, k);
for i=1:k
    a = postpara(i);
    b = total - a;
    f(:,i) = betapdf(u, a, b);
    legendtext{i} = sprintf('%d (seen %d times)', i, seen(i));
end
subplot(2,1,1)
cla
plot(u, f, 'linewidth', 2);
title(sprintf('posterior densities after %d observations', n));
xlabel('parameter value');
ylabel('density');
legend(legendtext)
grid on

subplot(2,1,2)
cla
hold on   % draw all k intervals into the same axes
for i=1:k
    a = postpara(i);
    b = total - a;
    credint = betainv([0.025 0.975], a, b);
    plot(credint, [i i], '.-', 'linewidth', 2, 'markersize', 20);
    text(mean(credint), i, sprintf('%d',i), ...
        'VerticalAlignment','top', 'HorizontalAlignment','center')
end
hold off
set(gca, 'xlim', [0 1]);
set(gca, 'ylim', [0.5 k+0.5]);
set(gca, 'ydir', 'reverse')
set(gca, 'ytick', [])
title('95% credible intervals')
grid on
drawnow
end

function x = read_roll(k)
% Prompt for one roll and repeat until the answer is an integer in 0..k
% (0 means "stop"). Empty answers (just pressing Enter) and non-numeric
% answers are re-prompted instead of making the range checks error out.
x = input('next roll: ');
while ~isnumeric(x) || ~isscalar(x) || x<0 || x>k || x~=round(x)
    x = input('invalid input\nnext roll: ');
end
end