MissingFeaturesKernelExpansion.h
//===========================================================================
/*!
 *
 *
 * \brief A kernel expansion with support of missing features
 *
 *
 *
 * \author B. Li
 * \date 2012
 *
 *
 * \par Copyright 1995-2017 Shark Development Team
 *
 * <BR><HR>
 * This file is part of Shark.
 * <https://shark-ml.github.io/Shark/>
 *
 * Shark is free software: you can redistribute it and/or modify
 * it under the terms of the GNU Lesser General Public License as published
 * by the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 * Shark is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
 * GNU Lesser General Public License for more details.
 *
 * You should have received a copy of the GNU Lesser General Public License
 * along with Shark. If not, see <http://www.gnu.org/licenses/>.
 *
 */
//===========================================================================
#include <shark/Data/Dataset.h>
#include <shark/Data/DataView.h>
#include <shark/Models/Kernels/EvalSkipMissingFeatures.h>
#include <shark/Models/Kernels/KernelExpansion.h>

namespace shark {

/// \brief Kernel expansion with missing features support.
///
/// For a choice of kernel, see \ref kernels.
/// \ingroup models
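///
/// A minimal usage sketch (illustrative, not part of the original file; it
/// assumes missing feature values are encoded as NaN and that a trainer such
/// as MissingFeatureSvmTrainer has set up the basis, the coefficients, the
/// scaling coefficients and the classifier norm):
/// \code
/// GaussianRbfKernel<> kernel(0.5);           // any kernel from \ref kernels
/// MissingFeaturesKernelExpansion<RealVector> model(&kernel);
/// // ... train or otherwise initialize 'model' ...
/// RealVector prediction = model(testPoint);  // testPoint may contain NaN entries
/// \endcode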
template<class InputType>
class MissingFeaturesKernelExpansion : public KernelExpansion<InputType>
{
private:
    typedef KernelExpansion<InputType> Base;
public:
    typedef typename Base::KernelType KernelType;
    typedef typename Base::BatchInputType BatchInputType;
    typedef typename Base::BatchOutputType BatchOutputType;

    /// Constructors from the base class
    ///@{
    MissingFeaturesKernelExpansion()
    {}

    MissingFeaturesKernelExpansion(KernelType* kernel)
    : Base(kernel)
    {}

    MissingFeaturesKernelExpansion(KernelType* kernel, Data<InputType> const& basis, bool offset)
    : Base(kernel, basis, offset, 1u)
    {}
    ///@}

    /// \brief From INameable: return the class name.
    std::string name() const
    { return "MissingFeaturesKernelExpansion"; }

    boost::shared_ptr<State> createState()const{
        return boost::shared_ptr<State>(new EmptyState());
    }

    /// Override eval(...) in the base class.
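    ///
    /// In the file's notation, the loop below computes, for every pattern
    /// \f$ x \f$ in the batch (a sketch derived from the code; \f$ s_i \f$
    /// are the scaling coefficients and \f$ K \f$ is evaluated only over the
    /// features valid in both arguments),
    /// \f[ f(x) = b + \sum_i \frac{K(x_i, x)}{s_i \, c_x}\,\alpha_i,
    ///     \qquad c_x = \frac{\|w_x\|}{\|w\|}, \f]
    /// where \f$ \|w_x\| \f$ is the classifier norm computed with the
    /// missingness pattern of \f$ x \f$ (see computeNorm()).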
    virtual void eval(BatchInputType const& patterns, BatchOutputType& outputs)const{
        SHARK_ASSERT(Base::mep_kernel != NULL);
        SIZE_CHECK(Base::m_alpha.size1() > 0u);

        // TODO: iterate over the batches directly; the DataView is used only
        // for O(1) random-access lookup and is not strictly necessary.
        DataView<Data<InputType> const > indexedBasis(Base::m_basis);

        ensure_size(outputs, batchSize(patterns), Base::outputShape().numElements());
        if (Base::hasOffset())
            noalias(outputs) = repeat(Base::m_b, batchSize(patterns));
        else
            outputs.clear();

        for(std::size_t p = 0; p != batchSize(patterns); ++p){
            // Calculate the scaling coefficient for the current pattern
            const double patternNorm = computeNorm(column(Base::m_alpha, 0), m_scalingCoefficients, row(patterns, p));
            const double patternSc = patternNorm / m_classifierNorm;

            // Classify as usual, except that the kernel is evaluated only over
            // the features present in the input (missing features are skipped).
            // TODO: evaluate k for all i and replace the += with a matrix-vector
            // operation; better yet, do this for all p and i with a
            // matrix-matrix multiplication.
            for (std::size_t i = 0; i != indexedBasis.size(); ++i){
                const double k = evalSkipMissingFeatures(
                    *Base::mep_kernel,
                    indexedBasis[i],
                    row(patterns, p)) / m_scalingCoefficients[i] / patternSc;
                noalias(row(outputs, p)) += k * row(Base::m_alpha, i);
            }
        }
    }
    void eval(BatchInputType const& patterns, BatchOutputType& outputs, State& state)const{
        eval(patterns, outputs);
    }

    /// Calculate the norm of the classifier, i.e., \f$ \|w\| \f$
    ///
    /// Formula:
    /// \f$ \|w\| = \sqrt{\sum_{i,j=1}^{n}\alpha_i\frac{y_i}{s_i}K\left(x_i,x_j\right)\frac{y_j}{s_j}\alpha_j} \f$
    /// where \f$ s_i \f$ is the scaling coefficient and \f$ K \f$ is the kernel function;
    /// \f$ K\left(x_i,x_j\right) \f$ is evaluated only over the features that are valid for both \f$ x_i \f$ and \f$ x_j \f$.
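    ///
    /// A usage sketch (illustrative; this mirrors how eval() calls this
    /// overload, with `x` a hypothetical test point whose NaN entries mark
    /// its missing features):
    /// \code
    /// double wNormForX = computeNorm(column(m_alpha, 0), m_scalingCoefficients, x);
    /// \endcode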
    template<class InputTypeT>
    double computeNorm(
        const RealVector& alpha,
        const RealVector& scalingCoefficient,
        InputTypeT const& missingness
    ) const{
        SHARK_ASSERT(Base::mep_kernel != NULL);
        SIZE_CHECK(alpha.size() == scalingCoefficient.size());
        SIZE_CHECK(Base::m_basis.numberOfElements() == alpha.size());

        // Calculate ||w||^2
        double norm_sqr = 0.0;

        // TODO: iterate over the batches directly; the DataView is used only
        // for O(1) random-access lookup and is not strictly necessary.
        DataView<Data<InputType> const > indexedBasis(Base::m_basis);

        for (std::size_t i = 0; i < alpha.size(); ++i){
            for (std::size_t j = 0; j < alpha.size(); ++j){
                const double evalResult = evalSkipMissingFeatures(
                    *Base::mep_kernel,
                    indexedBasis[i],
                    indexedBasis[j],
                    missingness);
                // Note: the Shark solver absorbs the labels by substituting
                // \alpha with y * \alpha, so no explicit y_i appears here.
                norm_sqr += evalResult * alpha(i) * alpha(j) / scalingCoefficient(i) / scalingCoefficient(j);
            }
        }

        // Return ||w||
        return std::sqrt(norm_sqr);
    }

    /// Overload of computeNorm() without an explicit missingness pattern.
    double computeNorm(
        const RealVector& alpha,
        const RealVector& scalingCoefficient
    ) const{
        SHARK_ASSERT(Base::mep_kernel != NULL);
        SIZE_CHECK(alpha.size() == scalingCoefficient.size());
        SIZE_CHECK(Base::m_basis.numberOfElements() == alpha.size());

        // TODO: iterate over the batches directly; the DataView is used only
        // for O(1) random-access lookup and is not strictly necessary.
        DataView<Data<InputType> const > indexedBasis(Base::m_basis);

        // Calculate ||w||^2
        double norm_sqr = 0.0;

        for (std::size_t i = 0; i < alpha.size(); ++i){
            for (std::size_t j = 0; j < alpha.size(); ++j){
                const double evalResult = evalSkipMissingFeatures(
                    *Base::mep_kernel,
                    indexedBasis[i],
                    indexedBasis[j]);
                // Note: the Shark solver absorbs the labels by substituting
                // \alpha with y * \alpha, so no explicit y_i appears here.
                norm_sqr += evalResult * alpha(i) * alpha(j) / scalingCoefficient(i) / scalingCoefficient(j);
            }
        }

        // Return ||w||
        return std::sqrt(norm_sqr);
    }

    void setScalingCoefficients(const RealVector& scalingCoefficients)
    {
#if DEBUG
        for(double v: scalingCoefficients)
        {
            SHARK_ASSERT(v > 0.0);
        }
#endif
        m_scalingCoefficients = scalingCoefficients;
    }

    void setClassifierNorm(double classifierNorm)
    {
        SHARK_ASSERT(classifierNorm > 0.0);
        m_classifierNorm = classifierNorm;
    }

protected:
    /// The scaling coefficients
    RealVector m_scalingCoefficients;

    /// The norm of the classifier, i.e., ||w||
    double m_classifierNorm;
};

} // namespace shark