113 RealMatrix
const& hiddenState,
114 RealMatrix
const& visibleInput,
115 BetaVector
const& beta
117 SIZE_CHECK(hiddenState.size1()==visibleInput.size1());
119 std::size_t
batchSize = hiddenState.size1();
122 RealVector energyTerms = m_hiddenNeurons.energyTerm(hiddenState,beta);
126 for(std::size_t i = 0; i !=
batchSize; ++i){
127 p(i) = m_visibleNeurons.logMarginalize(row(visibleInput,i),beta(i))+energyTerms(i);
142 RealMatrix
const& visibleState,
143 RealMatrix
const& hiddenInput,
144 BetaVector
const& beta
146 SIZE_CHECK(visibleState.size1()==hiddenInput.size1());
147 SIZE_CHECK(visibleState.size1()==beta.size());
148 std::size_t
batchSize = visibleState.size1();
151 RealVector energyTerms = m_visibleNeurons.energyTerm(visibleState,beta);
154 for(std::size_t i = 0; i !=
batchSize; ++i){
155 p(i) = m_hiddenNeurons.logMarginalize(row(hiddenInput,i),beta(i))+energyTerms(i);
167 SIZE_CHECK(visibleStates.size1() == beta.size());
169 RealMatrix hiddenInputs(beta.size(),m_hiddenNeurons.size());
180 SIZE_CHECK(hiddenStates.size1() == beta.size());
182 RealMatrix visibleInputs(beta.size(),m_visibleNeurons.size());
193 RealMatrix
const& hiddenInput,
194 RealMatrix
const& hidden,
195 RealMatrix
const& visible
197 RealMatrix
const& phiOfH = m_hiddenNeurons.phi(hidden);
198 std::size_t
batchSize = hiddenInput.size1();
200 for(std::size_t i = 0; i !=
batchSize; ++i){
201 energies(i) = -inner_prod(row(hiddenInput,i),row(phiOfH,i));
203 energies -= m_hiddenNeurons.energyTerm(hidden,blas::repeat(1.0,
batchSize));
204 energies -= m_visibleNeurons.energyTerm(visible,blas::repeat(1.0,
batchSize));
215 RealMatrix
const& visibleInput,
216 RealMatrix
const& hidden,
217 RealMatrix
const& visible
219 RealMatrix
const& phiOfV = m_visibleNeurons.phi(visible);
220 std::size_t
batchSize = visibleInput.size1();
222 for(std::size_t i = 0; i !=
batchSize; ++i){
223 energies(i) = -inner_prod(row(phiOfV,i),row(visibleInput,i));
225 energies -= m_hiddenNeurons.energyTerm(hidden,blas::repeat(1.0,
batchSize));
226 energies -= m_visibleNeurons.energyTerm(visible,blas::repeat(1.0,
batchSize));