MultiNomialDistribution.h
Go to the documentation of this file.
1/*!
2 *
3 *
4 * \brief Implements a multinomial distribution
5 *
6 *
7 *
8 * \author O.Krause
9 * \date 2016
10 *
11 *
12 * \par Copyright 1995-2017 Shark Development Team
13 *
14 * <BR><HR>
15 * This file is part of Shark.
16 * <https://shark-ml.github.io/Shark/>
17 *
18 * Shark is free software: you can redistribute it and/or modify
19 * it under the terms of the GNU Lesser General Public License as published
20 * by the Free Software Foundation, either version 3 of the License, or
21 * (at your option) any later version.
22 *
23 * Shark is distributed in the hope that it will be useful,
24 * but WITHOUT ANY WARRANTY; without even the implied warranty of
25 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
26 * GNU Lesser General Public License for more details.
27 *
28 * You should have received a copy of the GNU Lesser General Public License
29 * along with Shark. If not, see <http://www.gnu.org/licenses/>.
30 *
31 */
32#ifndef SHARK_STATISTICS_MULTINOMIALDISTRIBUTION_H
33#define SHARK_STATISTICS_MULTINOMIALDISTRIBUTION_H
34
35#include <shark/LinAlg/Base.h>
36#include <shark/Core/Random.h>
37
38namespace shark {
39
40/// \brief Implements a multinomial distribution.
41///
42/// A multinomial distribution is a discrete distribution with states 0,...,N-1
43/// and probabilities p_i for state i with sum_i p_i = 1. This implementation uses
44/// the fast alias method (Kronmal and Peterson,1979) to draw the numbers in
45/// constant time. Setup is O(N) and also quite fast. It is advisable
46/// to use this method to draw many numbers in succession.
47///
48/// The idea of the alias method is to pair a state with high probability with a state with low
49/// probability. A high probability state can in this case be included in several pairs. To draw,
50/// first one of the states is selected and afterwards a coin toss decides which element of the pair
51/// is taken.
53public:
54 typedef std::size_t result_type; ///< Type returned by operator(): the index of the sampled state.
55
57
58 /// \brief Constructor
59 /// \param [in] probabilities Probability vector
61 : m_probabilities(probabilities){
62 update();
63 }
64
65 /// \brief Stores/Restores the distribution from the supplied archive.
66 /// \param [in,out] ar The archive to read from/write to.
67 /// \param [in] version Currently unused.
68 template<typename Archive>
69 void serialize( Archive & ar, const unsigned int version ) {
70 ar & BOOST_SERIALIZATION_NVP( m_probabilities );
71 ar & BOOST_SERIALIZATION_NVP( m_q );
72 ar & BOOST_SERIALIZATION_NVP( m_J );
73 }
74
75 /// \brief Accesses the probabilityvector defining the distribution.
76 RealVector const& probabilities() const {
77 return m_probabilities;
78 }
79
80 /// \brief Accesses a mutable reference to the probability vector
81 /// defining the distribution. Allows for l-value semantics.
82 ///
83 /// ATTENTION: If the reference is altered, update needs to be called manually.
84 RealVector& probabilities() {
85 return m_probabilities;
86 }
87
88 /// \brief Samples the distribution.
89 template<class randomType>
90 result_type operator()(randomType& rng) const {
91 std::size_t numStates = m_probabilities.size();
92
93 std::size_t index = random::discrete(rng,std::size_t(0),numStates-1);
94
95 if(random::coinToss(rng, m_q[index]))
96 return index;
97 else
98 return m_J[index];
99 }
100
101
102 void update() {
103 std::size_t numStates = m_probabilities.size();
104 m_q.resize(numStates);
105 m_J.resize(numStates);
106 m_probabilities/=sum(m_probabilities);
107
108 // Sort the data into the outcomes with probabilities
109 // that are larger and smaller than 1/K.
110 std::deque<std::size_t> smaller;
111 std::deque<std::size_t> larger;
112 for(std::size_t i = 0;i != numStates; ++i){
113 m_q(i) = numStates*m_probabilities(i);
114 if(m_q(i) < 1.0)
115 smaller.push_back(i);
116 else
117 larger.push_back(i);
118 }
119 // Loop though and create little binary mixtures that
120 // appropriately allocate the larger outcomes over the
121 // overall uniform mixture.
122 while(!smaller.empty() && !larger.empty()){
123 std::size_t smallIndex = smaller.front();
124 std::size_t largeIndex = larger.front();
125 smaller.pop_front();
126 larger.pop_front();
127
128 m_J[smallIndex] = largeIndex;
129 m_q[largeIndex] -= 1.0 - m_q[smallIndex];
130
131 if(m_q[largeIndex] < 1.0)
132 smaller.push_back(largeIndex);
133 else
134 larger.push_back(largeIndex);
135 }
136 for(std::size_t i = 0; i != larger.size(); ++i){
137 m_q[larger[i]]=std::min(m_q[larger[i]],1.0);
138 }
139 }
140
141private:
142 RealVector m_probabilities; ///< probability of every state; divided by its sum in update().
143 RealVector m_q; ///< m_q[i] is the coin-toss parameter used when slot i is selected: on success state i is drawn, otherwise J[i].
144 blas::vector<std::size_t> m_J; ///< defines the second element of the pair (i,J[i]), i.e. the alias state of slot i.
145};
146}
147
148#endif