45#ifndef SHARK_DATA_WEIGHTED_DATASET_H
46#define SHARK_DATA_WEIGHTED_DATASET_H
52template<
class DataType,
class WeightType>
59 template<
class DataT,
class WeightT>
65 template<
class DataT,
class WeightT>
70 template<
class DataT,
class WeightT>
83template<
class D1,
class W1,
class D2,
class W2>
86 swap(std::forward<D1>(p1.data),std::forward<D2>(p2.data));
87 swap(std::forward<W1>(p1.weight),std::forward<W2>(p2.weight));
90template<
class DataBatchType,
class WeightBatchType>
100 typename DataBatchTraits::value_type,
101 typename WeightBatchTraits::value_type
108 decltype(
getBatchElement(std::declval<
typename std::add_const<DataBatchType>::type&>(),0)),
109 decltype(
getBatchElement(std::declval<
typename std::add_const<WeightBatchType>::type&>(),0))
114 template<
class D,
class W>
122 std::size_t
size,Pair
const& p
125 template<
class I,
class L>
133 return DataBatchTraits::size(
data);
158template<
class D1,
class W1,
class D2,
class W2>
165template<
class DataType,
class WeightType>
167:
public detail::SimpleBatch<
168 WeightedDataBatch<typename detail::element_to_batch<DataType>::type, typename detail::element_to_batch<WeightType>::type>
171template<
class DataType,
class WeightType>
173 typedef typename detail::batch_to_element<DataType>::type
DataElem;
174 typedef typename detail::batch_to_element<WeightType>::type
WeightElem;
180template <
class DataContainerT>
184 typedef typename DataContainerT::element_type DataType;
185 typedef double WeightType;
186 typedef DataContainerT DataContainer;
188 typedef typename DataContainer::IndexSet IndexSet;
197 typename DataContainer::batch_type,
198 typename WeightContainer::batch_type
203 typename DataContainer::batch_reference,
207 typename DataContainer::const_batch_reference,
209 > const_batch_reference;
214 typedef boost::iterator_range< detail::DataElementIterator<BaseWeightedDataset<DataContainer> > > element_range;
215 typedef boost::iterator_range< detail::DataElementIterator<BaseWeightedDataset<DataContainer>
const> > const_element_range;
216 typedef detail::BatchRange<BaseWeightedDataset<DataContainer> > batch_range;
217 typedef detail::BatchRange<BaseWeightedDataset<DataContainer>
const> const_batch_range;
224 const_element_range elements()
const{
225 return const_element_range(
226 detail::DataElementIterator<BaseWeightedDataset<DataContainer>
const>(
this,0,0,0),
227 detail::DataElementIterator<BaseWeightedDataset<DataContainer>
const>(
this,numberOfBatches(),0,numberOfElements())
234 element_range elements(){
235 return element_range(
236 detail::DataElementIterator<BaseWeightedDataset<DataContainer> >(
this,0,0,0),
237 detail::DataElementIterator<BaseWeightedDataset<DataContainer> >(
this,numberOfBatches(),0,numberOfElements())
245 const_batch_range batches()
const{
246 return const_batch_range(
this);
252 batch_range batches(){
253 return batch_range(
this);
257 std::size_t numberOfBatches()
const{
258 return m_data.numberOfBatches();
261 std::size_t numberOfElements()
const{
262 return m_data.numberOfElements();
267 return m_data.empty();
271 DataContainer
const& data()
const{
275 DataContainer& data(){
280 WeightContainer
const& weights()
const{
284 WeightContainer& weights(){
291 BaseWeightedDataset()
297 BaseWeightedDataset(std::size_t numBatches)
298 : m_data(numBatches),m_weights(numBatches)
308 BaseWeightedDataset(std::size_t size, element_type
const& element, std::size_t
batchSize)
310 , m_weights(size,element.weight,
batchSize)
317 BaseWeightedDataset(DataContainer
const& data,
Data<WeightType> const& weights)
318 : m_data(data), m_weights(weights)
320 SHARK_RUNTIME_CHECK(data.numberOfElements() == weights.
numberOfElements(),
"[ BaseWeightedDataset::WeightedUnlabeledData] number of data and number of weights must agree");
322 for(std::size_t i = 0; i != data.numberOfBatches(); ++i){
329 BaseWeightedDataset(DataContainer
const& data,
double weight)
330 : m_data(data), m_weights(data.numberOfBatches())
332 for(std::size_t i = 0; i != numberOfBatches(); ++i){
339 element_reference element(std::size_t i){
340 return *(detail::DataElementIterator<BaseWeightedDataset<DataContainer> >(
this,0,0,0)+i);
342 const_element_reference element(std::size_t i)
const{
343 return *(detail::DataElementIterator<BaseWeightedDataset<DataContainer>
const>(
this,0,0,0)+i);
347 batch_reference batch(std::size_t i){
348 return batch_reference(m_data.batch(i),m_weights.batch(i));
350 const_batch_reference batch(std::size_t i)
const{
351 return const_batch_reference(m_data.batch(i),m_weights.batch(i));
369 virtual void makeIndependent(){
370 m_weights.makeIndependent();
371 m_data.makeIndependent();
379 void splitBatch(std::size_t batch, std::size_t elementIndex){
380 m_data.splitBatch(batch,elementIndex);
381 m_weights.splitBatch(batch,elementIndex);
388 void append(BaseWeightedDataset
const& other){
389 m_data.append(other.m_data);
390 m_weights.append(other.m_weights);
398 template<
class Range>
399 void repartition(Range
const& batchSizes){
400 m_data.repartition(batchSizes);
401 m_weights.repartition(batchSizes);
408 std::vector<std::size_t> getPartitioning()
const{
409 return m_data.getPartitioning();
412 friend void swap( BaseWeightedDataset& a, BaseWeightedDataset& b){
413 swap(a.m_data,b.m_data);
414 swap(a.m_weights,b.m_weights);
421 BaseWeightedDataset indexedSubset(IndexSet
const& indices)
const{
422 BaseWeightedDataset
subset;
423 subset.m_data = m_data.indexedSubset(indices);
424 subset.m_weights = m_weights.indexedSubset(indices);
428 DataContainer m_data;
429 WeightContainer m_weights;
445template <
class DataT>
449 typedef detail::BaseWeightedDataset <UnlabeledData<DataT> > base_type;
451 using base_type::data;
452 using base_type::weights;
455 typedef typename base_type::element_type element_type;
470 : base_type(numBatches)
488 : base_type(data,weights)
493 : base_type(data,weight)
509 return data().shape();
514 return data().shape();
525 swap(
static_cast<base_type&
>(a),
static_cast<base_type&
>(b));
532 for(
auto elem: d.elements())
533 stream << elem.weight <<
" [" << elem.data<<
"]"<<
"\n";
538template<
class DataRange,
class WeightRange>
539typename boost::disable_if<
540 boost::is_arithmetic<WeightRange>,
541 WeightedUnlabeledData<
542 typename boost::range_value<DataRange>::type
548 typedef typename boost::range_value<DataRange>::type
Data;
577template <
class InputT,
class LabelT>
581 typedef detail::BaseWeightedDataset <LabeledData<InputT,LabelT> > base_type;
587 typedef typename base_type::element_type element_type;
589 using base_type::data;
590 using base_type::weights;
604 : base_type(numBatches)
622 : base_type(data,weights)
627 : base_type(data,weight)
641 return data().labels();
645 return data().labels();
682 swap(
static_cast<base_type&
>(a),
static_cast<base_type&
>(b));
687template<
class T,
class U>
689 for(
auto elem: d.elements())
690 stream << elem.weight <<
" ("<< elem.data.label <<
" [" << elem.data.input<<
"] )"<<
"\n";
706template <
class InputType>
712template <
class InputType,
class LabelType>
718template <
class InputType,
class LabelType>
723template <
class InputType>
729template<
class InputType,
class LabelType>
735template<
class InputType>
737 double weightSum = 0;
738 for(std::size_t i = 0; i != dataset.numberOfBatches(); ++i){
739 weightSum += sum(dataset.batch(i).weight);
744template<
class InputType,
class LabelType>
746 double weightSum = 0;
747 for(std::size_t i = 0; i != dataset.numberOfBatches(); ++i){
748 weightSum += sum(dataset.batch(i).weight);
754template<
class InputType>
757 for(
auto const& elem: dataset.elements()){
758 weights(elem.data.label) += elem.weight;
766template<
class InputRange,
class LabelRange,
class WeightRange>
767typename boost::disable_if<
768 boost::is_arithmetic<WeightRange>,
770 typename boost::range_value<InputRange>::type,
771 typename boost::range_value<LabelRange>::type
776 "number of inputs and number of labels must agree");
778 "number of data points and number of weights must agree");
779 typedef typename boost::range_value<InputRange>::type
InputType;
780 typedef typename boost::range_value<LabelRange>::type LabelType;
800template<
class InputType,
class LabelType>
803 std::size_t bootStrapSize = 0
805 if(bootStrapSize == 0)
810 for(std::size_t i = 0; i != bootStrapSize; ++i){
812 bootstrapSet.element(index).weight += 1.0;
828template<
class InputType>
831 std::size_t bootStrapSize = 0
833 if(bootStrapSize == 0)
838 for(std::size_t i = 0; i != bootStrapSize; ++i){
840 bootstrapSet.element(index).weight += 1.0;