MNIST.h
Go to the documentation of this file.
1/*!
2 * \brief Loads the MNIST benchmark problem.
3 *
4 * \author O. Krause, A.Fischer, K.Bruegge
5 * \date 2012
6 *
7 *
8 * \par Copyright 1995-2017 Shark Development Team
9 *
10 * <BR><HR>
11 * This file is part of Shark.
12 * <https://shark-ml.github.io/Shark/>
13 *
14 * Shark is free software: you can redistribute it and/or modify
15 * it under the terms of the GNU Lesser General Public License as published
16 * by the Free Software Foundation, either version 3 of the License, or
17 * (at your option) any later version.
18 *
19 * Shark is distributed in the hope that it will be useful,
20 * but WITHOUT ANY WARRANTY; without even the implied warranty of
21 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
22 * GNU Lesser General Public License for more details.
23 *
24 * You should have received a copy of the GNU Lesser General Public License
25 * along with Shark. If not, see <http://www.gnu.org/licenses/>.
26 *
27 */
28#ifndef UNSUPERVISED_RBM_PROBLEMS_MNIST_H
29#define UNSUPERVISED_RBM_PROBLEMS_MNIST_H
30
31#include <shark/Data/Dataset.h>
32#include <shark/LinAlg/Base.h>
33#include <shark/Core/Random.h>
34
35#include <sstream>
36#include <fstream>
37#include <string>
38namespace shark{
39
40/// \brief Reads in the famous MNIST data in possibly binarized form. The MNIST database itself is not included in Shark,
41/// this class just helps loading it.
42///
43///MNIST is a set of handwritten digits.
44///It needs the filename of the file containing the database (can be downloaded from the web)
45///and the threshold for binarization. The threshold (between 0 and 255) describes when a gray value will be interpreted
46///as 1. Default is 127. If the threshold is 0, no binarization takes place.
47class MNIST{
48private:
50 std::string m_filename;
51 char m_threshold;
52 std::size_t m_batchSize;
53
54 int readInt (unsigned char *memblock) const{
55 return ((int)memblock[0] << 24) + ((int)memblock[1] << 16) + ((int)memblock[2] << 8) + memblock[3];
56 }
57 void init(){
58 //m_name="MNIST";
59 std::ifstream infile(m_filename.c_str(), std::ios::binary);
60 SHARK_RUNTIME_CHECK(infile, "Can not open file!");
61
62 //get file size
63 infile.seekg(0,std::ios::end);
64 std::ifstream::pos_type inputSize = infile.tellg();
65
66
67 unsigned char *memblock = new unsigned char [inputSize];
68 infile.seekg (0, std::ios::beg);
69 infile.read ((char *) memblock, inputSize);
70
71 SHARK_RUNTIME_CHECK(readInt(memblock) == 2051, "magic number for mnist wrong!");
72 std::size_t numImages = readInt(memblock + 4);
73 std::size_t numRows = readInt(memblock + 8);
74 std::size_t numColumns = readInt(memblock + 12);
75 std::size_t sizeOfVis = numRows * numColumns;
76
77 std::vector<RealVector> data(numImages,RealVector(sizeOfVis));
78 for (std::size_t i = 0; i != numImages; ++i){
79 RealVector imgVec(sizeOfVis);
80 if(m_threshold != 0){
81 for (size_t j = 0; j != sizeOfVis; ++j){
82 char pixel = memblock[ 16 + i * sizeOfVis + j ] > m_threshold;
83 data[i](j) = pixel;
84 }
85 }
86 else{
87 for (size_t j = 0; j != sizeOfVis; ++j){
88 data[i](j) = memblock[ 16 + i * sizeOfVis + j ];
89 }
90 }
91 }
92 delete [] memblock;
93 m_data = createDataFromRange(data,m_batchSize);
94 }
95public:
96
97 //Constructor. Sets the configurations from a property tree and imports the data set.
98 //@param filename the name of the file storing the dataset
99 //@param threshhold the threshold for turning gray values into ones
100 //@param batchSize the size of the batch
101 MNIST(std::string filename, char threshold = 127, std::size_t batchSize = 256)
102 : m_filename(filename), m_threshold(threshold), m_batchSize(batchSize){
103 init();
104 }
105
106 //Returns the data vector
108 return m_data;
109 }
110
111 //Returns the dimension of the pattern of MNIST.
112 std::size_t inputDimension() const {
113 return 28*28;
114 }
115
116 //Returns the batch size.
117 std::size_t batchSize() const {
118 return m_batchSize;
119 }
120
121};
122}
123#endif
124