CARTree.h
Go to the documentation of this file.
1//===========================================================================
2/*!
3 *
4 *
 * \brief CART (Classification and Regression Tree) model
6 *
7 *
8 *
9 * \author K. N. Hansen, J. Kremer
10 * \date 2012
11 *
12 *
13 * \par Copyright 1995-2017 Shark Development Team
14 *
15 * <BR><HR>
16 * This file is part of Shark.
17 * <https://shark-ml.github.io/Shark/>
18 *
19 * Shark is free software: you can redistribute it and/or modify
20 * it under the terms of the GNU Lesser General Public License as published
21 * by the Free Software Foundation, either version 3 of the License, or
22 * (at your option) any later version.
23 *
24 * Shark is distributed in the hope that it will be useful,
25 * but WITHOUT ANY WARRANTY; without even the implied warranty of
26 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
27 * GNU Lesser General Public License for more details.
28 *
29 * You should have received a copy of the GNU Lesser General Public License
30 * along with Shark. If not, see <http://www.gnu.org/licenses/>.
31 *
32 */
33//===========================================================================
34
35#ifndef SHARK_MODELS_TREES_CARTree_H
36#define SHARK_MODELS_TREES_CARTree_H
37
38
#include <shark/Models/AbstractModel.h>
#include <shark/Data/Dataset.h>

#include <deque>
#include <string>
#include <vector>
41namespace shark {
42
43
44///
45/// \brief Classification and Regression Tree.
46///
47/// \par
48/// The CARTree predicts a class label using a decision tree
49/// \ingroup models
50template<class LabelType>
51class CARTree : public AbstractModel<RealVector,LabelType>
52{
53private:
55public:
58
59 struct Node{
60 std::size_t attributeIndex;
62 std::size_t leftId;
63 std::size_t rightIdOrIndex;
64
65 template<class Archive>
66 void serialize(Archive & ar, const unsigned int version){
67 ar & attributeIndex;
68 ar & attributeValue;
69 ar & leftId;
70 ar & rightIdOrIndex;///< either id of right node or index to label array
71 }
72 };
73 typedef std::vector<Node> TreeType;
74
75
76 /// Constructor
77 CARTree(): m_inputDimension(0){}
78
80 : m_inputDimension(inputDimension)
81 , m_outputShape(outputShape){}
82
83
84 /// \brief From INameable: return the class name.
85 std::string name() const
86 { return "CARTree"; }
87
88 boost::shared_ptr<State> createState() const{
89 return boost::shared_ptr<State>(new EmptyState());
90 }
91
92 using base_type::eval;
93 /// \brief Evaluate the Tree on a batch of patterns
94 void eval(BatchInputType const& patterns, BatchOutputType & outputs) const{
95 std::size_t numPatterns = patterns.size1();
96 //evaluate the first pattern alone and create the batch output from that
97 LabelType const& firstResult = evalPattern(row(patterns,0));
98 outputs = Batch<LabelType>::createBatch(firstResult,numPatterns);
99 getBatchElement(outputs,0) = firstResult;
100
101 //evaluate the rest
102 for(std::size_t i = 0; i != numPatterns; ++i){
103 getBatchElement(outputs,i) = evalPattern(row(patterns,i));
104 }
105 }
106
107 void eval(BatchInputType const& patterns, BatchOutputType & outputs, State& state) const{
108 eval(patterns,outputs);
109 }
110 /// \brief Evaluate the Tree on a single pattern
111 void eval(RealVector const& pattern, LabelType& output){
112 output = evalPattern(pattern);
113 }
114
115 /// \brief The model does not have any parameters.
116 std::size_t numberOfParameters() const{
117 return 0;
118 }
119
120 /// \brief The model does not have any parameters.
121 RealVector parameterVector() const {
122 return RealVector();
123 }
124
125 /// \brief The model does not have any parameters.
126 void setParameterVector(RealVector const& param) {
127 SHARK_ASSERT(param.size() == 0);
128 }
129
130 /// from ISerializable, reads a model from an archive
131 void read(InArchive& archive){
132 archive >> m_tree;
133 archive >> m_labels;
134 archive >> m_inputDimension;
135 archive >> m_outputShape;
136 }
137
138 /// from ISerializable, writes a model to an archive
139 void write(OutArchive& archive) const {
140 archive << m_tree;
141 archive << m_labels;
142 archive << m_inputDimension;
143 archive << m_outputShape;
144 }
145
146 //Count how often attributes are used
147 UIntVector countAttributes() const {
148 SHARK_ASSERT(m_inputDimension > 0);
149 UIntVector r(m_inputDimension, 0);
150 for(auto it = m_tree.begin(); it != m_tree.end(); ++it) {
151 if(it->leftId != 0) { // not a label
152 r(it->attributeIndex)++;
153 }
154 }
155 return r;
156 }
157
158 ///Return input dimension
160 return m_inputDimension;
161 }
163 return m_outputShape;
164 }
165
166 ////////////////////////////////
167 /////Tree Construction routines
168 ///////////////////////////////
169
170 std::size_t numberOfNodes() const{
171 return m_tree.size();
172 }
173
174 /// \brief Returns the node with id nodeId
175 Node& getNode(std::size_t nodeId){
176 SIZE_CHECK(nodeId < m_tree.size());
177 return m_tree[nodeId];
178 }
179 /// \brief Returns the node with id nodeId
180 Node const& getNode(std::size_t nodeId)const{
181 SIZE_CHECK(nodeId < m_tree.size());
182 return m_tree[nodeId];
183 }
184
185 LabelType const& getLabel(std::size_t nodeId)const{
186 SIZE_CHECK(nodeId < m_tree.size());
187 return m_labels[m_tree[nodeId].rightIdOrIndex];
188 }
189
190 /// \brief Creates and returns an untyped root node (neither internal, nor leaf node)
192 m_tree.clear();
193 Node root;
194 root.leftId = 0;
195 root.rightIdOrIndex = 0;
196 m_tree.push_back(root);
197 return m_tree[0];
198 }
199
200
201 ///\brief Transforms an untyped node (no child, no internal node) into an internal node
202 ///
203 /// This creates already the two childs of the node, which are untyped.
204 Node& transformInternalNode(std::size_t nodeId, std::size_t attributeIndex, double attributeValue) {
205 // ids for new child nodes
206 int nodeIdLeft = m_tree.size();
207 int nodeIdRight = m_tree.size() + 1;
208
209 //create new child nodes
210 Node leftChild;
211 leftChild.leftId = 0;
212 leftChild.rightIdOrIndex = 0;
213
214 Node rightChild;
215 rightChild.leftId = 0;
216 rightChild.rightIdOrIndex = 0;
217
218 m_tree.push_back(leftChild);
219 m_tree.push_back(rightChild);
220
221 // connect the parent node with its two childs
222 m_tree[nodeId].leftId = nodeIdLeft;
223 m_tree[nodeId].rightIdOrIndex = nodeIdRight;
224 m_tree[nodeId].attributeIndex = attributeIndex;
225 m_tree[nodeId].attributeValue = attributeValue;
226
227 return m_tree[nodeId];
228 }
229 ///\brief Transforms a node (no leaf) into a leaf node and inserts the appropriate label
230 ///
231 /// If the node was an internal node before, its connections get removed and the childs
232 /// are not reachable any more. Calling a reorder routine like reorderBFS() will get rid of those
233 /// nodes.
234 Node& transformLeafNode(std::size_t nodeId, LabelType const& label){
235 Node& node = m_tree[nodeId];
236 node.attributeIndex = 0;
237 node. attributeValue = 0.0;
238 node.leftId = 0;
239 node.rightIdOrIndex = m_labels.size();
240 m_labels.push_back(label);
241 return node;
242 }
243
244 /// \brief Reorders a tree into a breath-first-ordering
245 ///
246 /// This function call will remove all unreachable subtrees while reordering
247 /// the nodes by their depth in the tree, i.e. first comes the root, the the children
248 /// of the root, than their children, etc.
250 TreeType reordered_tree;
251 reordered_tree.reserve(m_tree.size());
252
253 std::deque<std::size_t > bfs_queue;
254 bfs_queue.push_back(0);
255
256 std::size_t nodeId = 0; //running id of the next node to insert
257 while(!bfs_queue.empty()){
258 Node const& node = getNode(bfs_queue.front());
259 bfs_queue.pop_front();
260
261 //check leaf
262 if(!node.leftId == 0){
263 reordered_tree.push_back(node);
264 }else{
265 reordered_tree.push_back(node);
266 reordered_tree.back().leftId = nodeId+1;
267 reordered_tree.back().rightIdOrIndex = nodeId+2;
268 nodeId += 2;
269 bfs_queue.push_back(node.leftId);
270 bfs_queue.push_back(node.rightIdOrIndex);
271 }
272 }
273 //overwrite old tree with pruned tree
274 m_tree = std::move(reordered_tree);
275 }
276
277 /// Find the leaf of the tree for a sample
278 template<class Vector>
279 std::size_t findLeaf(Vector const& pattern) const{
280 std::size_t nodeId = 0;
281 while(m_tree[nodeId].leftId != 0){
282 if(pattern[m_tree[nodeId].attributeIndex] <= m_tree[nodeId].attributeValue){
283 //Branch on left node
284 nodeId = m_tree[nodeId].leftId;
285 }else{
286 //Branch on right node
287 nodeId = m_tree[nodeId].rightIdOrIndex;
288 }
289 }
290 return nodeId;
291 }
292
293private:
294 /// tree of the model
295 TreeType m_tree;
296 std::vector<LabelType> m_labels;
297 ///Number of attributes (set by trainer)
298 std::size_t m_inputDimension;
299 Shape m_outputShape;
300
301 /// Evaluate the CART tree on a single sample
302 template<class Vector>
303 LabelType const& evalPattern(Vector const& pattern) const{
304 auto nodeId = findLeaf(pattern);
305 return m_labels[m_tree[nodeId].rightIdOrIndex];
306 }
307};
308
309
310}
311#endif