35#ifndef SHARK_DATA_CVDATASETTOOLS_H
36#define SHARK_DATA_CVDATASETTOOLS_H
48template<
class DatasetTypeT>
52 typedef typename DatasetType::IndexSet
IndexSet;
61 std::vector<IndexSet>
const &validationIndizes
62 ) : m_dataset(set),m_validationFolds(validationIndizes) {}
66 std::vector<std::size_t>
const &foldStart
68 for (std::size_t partition = 0; partition != foldStart.size(); partition++) {
69 std::size_t partitionSize = (partition+1 == foldStart.size()) ? set.numberOfBatches() : foldStart[partition+1];
70 partitionSize -= foldStart[partition];
73 IndexSet validationIndizes(partitionSize);
74 for (std::size_t batch = 0; batch != partitionSize; ++batch) {
75 validationIndizes[batch]=batch+foldStart[partition];
77 m_validationFolds.push_back(validationIndizes);
93 return m_validationFolds[i];
99 detail::complement(m_validationFolds[i], m_dataset.numberOfBatches(), trainingFold);
105 return m_validationFolds.size();
120 std::vector<IndexSet> m_validationFolds;
121 std::size_t m_datasetSize;
122 std::vector<std::size_t> m_validationFoldSizes;
135template<
class I,
class L>
138 std::size_t numberOfPartitions,
139 std::vector< std::vector<std::size_t> > members,
148 std::size_t numClasses = members.size();
151 for (std::size_t c = 0; c != numClasses; c++) {
156 std::size_t nn = numInputs / numberOfPartitions;
157 std::size_t leftOver = numInputs % numberOfPartitions;
158 std::vector<std::size_t> validationSize(numberOfPartitions,nn);
159 for (std::size_t partition = 0; partition != leftOver; partition++) {
160 validationSize[partition]++;
164 std::vector<std::size_t> partitionStart;
165 std::vector<std::size_t> batchSizes;
166 std::size_t numBatches = batchPartitioning(validationSize,partitionStart,batchSizes,
batchSize);
171 std::vector<std::size_t> validationSetStart = partitionStart;
173 std::size_t fold = 0;
174 std::vector<std::vector<std::size_t> > batchElements(numberOfPartitions);
177 if ( cv_indices != NULL ) {
178 cv_indices->first.clear();
179 cv_indices->first.resize( numInputs );
180 cv_indices->second.clear();
181 cv_indices->second.resize( numInputs );
185 for (std::size_t c = 0; c != numClasses; c++) {
186 for (std::size_t i = 0; i != members[c].size(); i++) {
187 std::size_t oldPos = members[c][i];
188 std::size_t batchNumber = validationSetStart[fold];
190 batchElements[fold].push_back(oldPos);
192 if ( cv_indices != NULL ) {
193 cv_indices->first[ j ] = oldPos;
194 cv_indices->second[ j ] = fold;
199 if (batchElements[fold].size() == batchSizes[batchNumber]) {
200 newSet.
batch(validationSetStart[fold]) =
subBatch(setView,batchElements[fold]);
201 batchElements[fold].clear();
202 ++validationSetStart[fold];
205 fold = (fold+1) % numberOfPartitions;
239template<
class I,
class L>
241 std::size_t numberOfPartitions,
260template<
class I,
class L>
265 std::vector<std::size_t> validationSize(numberOfPartitions);
266 std::size_t inputsForValidation = numInputs / numberOfPartitions;
267 std::size_t leftOver = numInputs - inputsForValidation * numberOfPartitions;
268 for (std::size_t i = 0; i != numberOfPartitions; i++) {
269 std::size_t vs=inputsForValidation+(i<leftOver);
270 validationSize[i] =vs;
274 std::vector<std::size_t> partitionStart;
275 std::vector<std::size_t> batchSizes;
276 detail::batchPartitioning(validationSize,partitionStart,batchSizes,
batchSize);
300 std::size_t numberOfPartitions,
305 std::size_t numInputs = setView.
size();
310 std::vector< std::vector<std::size_t> > members(numClasses);
311 for (std::size_t i = 0; i != numInputs; i++) {
312 members[setView[i].label].push_back(i);
314 return detail::createCVSameSizeBalanced(set, numberOfPartitions, members,
batchSize, cv_indices);
326template<
class I,
class L>
329 std::size_t numberOfPartitions
338 std::vector<IndexSet> folds;
340 std::size_t remainder = set.
numberOfBatches() - partitionSize*numberOfPartitions;
341 std::vector<std::size_t>::iterator pos = indizes.begin();
342 for(std::size_t i = 0; i!= numberOfPartitions; ++i){
343 std::size_t size = partitionSize;
348 folds.push_back(IndexSet(pos,pos+size));
365template<
class I,
class L>
368 std::size_t numberOfPartitions,
369 std::vector<std::size_t> indices,
374 SIZE_CHECK(numberOfPartitions == *std::max_element(indices.begin(),indices.end())+1);
377 std::vector<std::size_t> validationSize(numberOfPartitions,0);
378 for (std::size_t input = 0; input != numInputs; input++) {
379 validationSize[indices[input]]++;
383 std::vector<std::size_t> partitionStart;
384 std::vector<std::size_t> batchSizes;
385 std::size_t numBatches = detail::batchPartitioning(validationSize,partitionStart,batchSizes,
batchSize);
390 std::vector<std::size_t> validationSetStart = partitionStart;
391 std::vector<std::vector<std::size_t> > batchElements(numberOfPartitions);
392 for (std::size_t input = 0; input != numInputs; input++) {
393 std::size_t partition = indices[input];
394 batchElements[partition].push_back(input);
397 std::size_t batchNumber = validationSetStart[partition];
398 if (batchElements[partition].size() == batchSizes[batchNumber]) {
399 newSet.
batch(validationSetStart[partition]) =
subBatch(setView,batchElements[partition]);
400 batchElements[partition].clear();
401 ++validationSetStart[partition];
426template<
class I,
class L>
429 std::size_t numberOfPartitions,
434 SIZE_CHECK(indices.first.size() == numInputs);
435 SIZE_CHECK(indices.second.size() == numInputs);
436 SIZE_CHECK(numberOfPartitions == *std::max_element(indices.second.begin(),indices.second.end())+1);
439 std::vector<std::size_t> validationSize(numberOfPartitions,0);
440 for (std::size_t input = 0; input != numInputs; input++) {
441 validationSize[indices.second[input]]++;
445 std::vector<std::size_t> partitionStart;
446 std::vector<std::size_t> batchSizes;
447 std::size_t numBatches = detail::batchPartitioning(validationSize,partitionStart,batchSizes,
batchSize);
452 std::vector<std::size_t> validationSetStart = partitionStart;
453 std::vector<std::vector<std::size_t> > batchElements(numberOfPartitions);
454 for (std::size_t input = 0; input != numInputs; input++) {
455 std::size_t partition = indices.second[input];
456 batchElements[partition].push_back( indices.first[input] );
459 std::size_t batchNumber = validationSetStart[partition];
460 if (batchElements[partition].size() == batchSizes[batchNumber]) {
461 newSet.
batch(validationSetStart[partition]) =
subBatch(setView,batchElements[partition]);
462 batchElements[partition].clear();
463 ++validationSetStart[partition];