45#ifndef SHARK_DATA_WEIGHTED_DATASET_H 
   46#define SHARK_DATA_WEIGHTED_DATASET_H 
   52template<
class DataType, 
class WeightType>
 
   59    template<
class DataT, 
class WeightT>
 
   65    template<
class DataT, 
class WeightT>
 
   70    template<
class DataT, 
class WeightT>
 
 
   83template<
class D1, 
class W1, 
class D2, 
class W2>
 
   86    swap(std::forward<D1>(p1.data),std::forward<D2>(p2.data));
 
   87    swap(std::forward<W1>(p1.weight),std::forward<W2>(p2.weight));
 
 
   90template<
class DataBatchType,
class WeightBatchType>
 
  100        typename DataBatchTraits::value_type,
 
  101        typename WeightBatchTraits::value_type
 
  108        decltype(
getBatchElement(std::declval<
typename std::add_const<DataBatchType>::type&>(),0)),
 
  109        decltype(
getBatchElement(std::declval<
typename std::add_const<WeightBatchType>::type&>(),0))
 
  114    template<
class D, 
class W>
 
  122        std::size_t 
size,Pair 
const& p
 
 
  125    template<
class I, 
class L>
 
  133        return DataBatchTraits::size(
data);
 
 
 
  158template<
class D1, 
class W1, 
class D2, 
class W2>
 
  165template<
class DataType, 
class WeightType>
 
  167: 
public detail::SimpleBatch<
 
  168    WeightedDataBatch<typename detail::element_to_batch<DataType>::type, typename detail::element_to_batch<WeightType>::type>
 
 
  171template<
class DataType, 
class WeightType>
 
  173    typedef typename detail::batch_to_element<DataType>::type 
DataElem;
 
  174    typedef typename detail::batch_to_element<WeightType>::type 
WeightElem;
 
 
  180template <
class DataContainerT>
 
  184    typedef typename DataContainerT::element_type DataType;
 
  185    typedef double WeightType;
 
  186    typedef DataContainerT DataContainer;
 
  188    typedef typename DataContainer::IndexSet IndexSet;
 
  197        typename DataContainer::batch_type,
 
  198        typename WeightContainer::batch_type
 
  203        typename DataContainer::batch_reference,
 
  207        typename DataContainer::const_batch_reference,
 
  209    > const_batch_reference;
 
  214    typedef boost::iterator_range< detail::DataElementIterator<BaseWeightedDataset<DataContainer> > > element_range;
 
  215    typedef boost::iterator_range< detail::DataElementIterator<BaseWeightedDataset<DataContainer> 
const> > const_element_range;
 
  216    typedef detail::BatchRange<BaseWeightedDataset<DataContainer> > batch_range;
 
  217    typedef detail::BatchRange<BaseWeightedDataset<DataContainer> 
const> const_batch_range;
 
  224    const_element_range elements()
const{
 
  225        return const_element_range(
 
  226            detail::DataElementIterator<BaseWeightedDataset<DataContainer> 
const>(
this,0,0,0),
 
  227            detail::DataElementIterator<BaseWeightedDataset<DataContainer> 
const>(
this,numberOfBatches(),0,numberOfElements())
 
  234    element_range elements(){
 
  235        return element_range(
 
  236            detail::DataElementIterator<BaseWeightedDataset<DataContainer> >(
this,0,0,0),
 
  237            detail::DataElementIterator<BaseWeightedDataset<DataContainer> >(
this,numberOfBatches(),0,numberOfElements())
 
  245    const_batch_range batches()
const{
 
  246        return const_batch_range(
this);
 
  252    batch_range batches(){
 
  253        return batch_range(
this);
 
  257    std::size_t numberOfBatches()
 const{
 
  258        return m_data.numberOfBatches();
 
  261    std::size_t numberOfElements()
 const{
 
  262        return m_data.numberOfElements();
 
  267        return m_data.empty();
 
  271    DataContainer 
const& data()
 const{
 
  275    DataContainer& data(){
 
  280    WeightContainer 
const& weights()
 const{
 
  284    WeightContainer& weights(){
 
  291    BaseWeightedDataset()
 
  297    BaseWeightedDataset(std::size_t numBatches)
 
  298    : m_data(numBatches),m_weights(numBatches)
 
  308    BaseWeightedDataset(std::size_t size, element_type 
const& element, std::size_t 
batchSize)
 
  310    , m_weights(size,element.weight,
batchSize)
 
  317    BaseWeightedDataset(DataContainer 
const& data, 
Data<WeightType> const& weights)
 
  318    : m_data(data), m_weights(weights)
 
  320        SHARK_RUNTIME_CHECK(data.numberOfElements() == weights.
numberOfElements(), 
"[ BaseWeightedDataset::WeightedUnlabeledData] number of data and number of weights must agree");
 
  322        for(std::size_t i  = 0; i != data.numberOfBatches(); ++i){
 
  329    BaseWeightedDataset(DataContainer 
const& data, 
double weight)
 
  330    : m_data(data), m_weights(data.numberOfBatches())
 
  332        for(std::size_t i = 0; i != numberOfBatches(); ++i){
 
  339    element_reference element(std::size_t i){
 
  340        return *(detail::DataElementIterator<BaseWeightedDataset<DataContainer> >(
this,0,0,0)+i);
 
  342    const_element_reference element(std::size_t i)
 const{
 
  343        return *(detail::DataElementIterator<BaseWeightedDataset<DataContainer> 
const>(
this,0,0,0)+i);
 
  347    batch_reference batch(std::size_t i){
 
  348        return batch_reference(m_data.batch(i),m_weights.batch(i));
 
  350    const_batch_reference batch(std::size_t i)
 const{
 
  351        return const_batch_reference(m_data.batch(i),m_weights.batch(i));
 
  369    virtual void makeIndependent(){
 
  370        m_weights.makeIndependent();
 
  371        m_data.makeIndependent();
 
  379    void splitBatch(std::size_t batch, std::size_t elementIndex){
 
  380        m_data.splitBatch(batch,elementIndex);
 
  381        m_weights.splitBatch(batch,elementIndex);
 
  388    void append(BaseWeightedDataset 
const& other){
 
  389        m_data.append(other.m_data);
 
  390        m_weights.append(other.m_weights);
 
  398    template<
class Range>
 
  399    void repartition(Range 
const& batchSizes){
 
  400        m_data.repartition(batchSizes);
 
  401        m_weights.repartition(batchSizes);
 
  408    std::vector<std::size_t> getPartitioning()
const{
 
  409        return m_data.getPartitioning();
 
  412    friend void swap( BaseWeightedDataset& a, BaseWeightedDataset& b){
 
  413        swap(a.m_data,b.m_data);
 
  414        swap(a.m_weights,b.m_weights);
 
  421    BaseWeightedDataset indexedSubset(IndexSet 
const& indices)
 const{
 
  422        BaseWeightedDataset 
subset;
 
  423        subset.m_data = m_data.indexedSubset(indices);
 
  424        subset.m_weights = m_weights.indexedSubset(indices);
 
  428    DataContainer m_data;               
 
  429    WeightContainer m_weights; 
 
  445template <
class DataT>
 
  449    typedef detail::BaseWeightedDataset <UnlabeledData<DataT> > base_type;
 
  451    using base_type::data;
 
  452    using base_type::weights;
 
  455    typedef typename base_type::element_type element_type;
 
  470    : base_type(numBatches)
 
 
  488    : base_type(data,weights)
 
 
  493    : base_type(data,weight)
 
 
  509        return data().shape();
 
 
  514        return data().shape();
 
 
  525        swap(
static_cast<base_type&
>(a),
static_cast<base_type&
>(b));
 
 
 
  532    for(
auto elem: d.elements())
 
  533        stream << elem.weight << 
" [" << elem.data<<
"]"<< 
"\n";
 
 
  538template<
class DataRange, 
class WeightRange>
 
  539typename boost::disable_if<
 
  540    boost::is_arithmetic<WeightRange>,
 
  541    WeightedUnlabeledData<
 
  542        typename boost::range_value<DataRange>::type
 
  548    typedef typename boost::range_value<DataRange>::type 
Data;
 
 
  577template <
class InputT, 
class LabelT>
 
  581    typedef detail::BaseWeightedDataset <LabeledData<InputT,LabelT> > base_type;
 
  587    typedef typename base_type::element_type element_type;
 
  589    using base_type::data;
 
  590    using base_type::weights;
 
  604    : base_type(numBatches)
 
 
  622    : base_type(data,weights)
 
 
  627    : base_type(data,weight)
 
 
  641        return data().labels();
 
 
  645        return data().labels();
 
 
  682        swap(
static_cast<base_type&
>(a),
static_cast<base_type&
>(b));
 
 
 
  687template<
class T, 
class U>
 
  689    for(
auto elem: d.elements())
 
  690        stream << elem.weight <<
" ("<< elem.data.label << 
" [" << elem.data.input<<
"] )"<< 
"\n";
 
 
  706template <
class InputType>
 
  712template <
class InputType, 
class LabelType>
 
  718template <
class InputType, 
class LabelType>
 
  723template <
class InputType>
 
  729template<
class InputType, 
class LabelType>
 
  735template<
class InputType>
 
  737    double weightSum = 0;
 
  738    for(std::size_t i = 0; i != dataset.numberOfBatches(); ++i){
 
  739        weightSum += sum(dataset.batch(i).weight);
 
 
  744template<
class InputType, 
class LabelType>
 
  746    double weightSum = 0;
 
  747    for(std::size_t i = 0; i != dataset.numberOfBatches(); ++i){
 
  748        weightSum += sum(dataset.batch(i).weight);
 
 
  754template<
class InputType>
 
  757    for(
auto const& elem: dataset.elements()){
 
  758        weights(elem.data.label) += elem.weight;
 
 
  766template<
class InputRange,
class LabelRange, 
class WeightRange>
 
  767typename boost::disable_if<
 
  768    boost::is_arithmetic<WeightRange>,
 
  770        typename boost::range_value<InputRange>::type,
 
  771        typename boost::range_value<LabelRange>::type
 
  776    "number of inputs and number of labels must agree");
 
  778    "number of data points and number of weights must agree");
 
  779    typedef typename boost::range_value<InputRange>::type 
InputType;
 
  780    typedef typename boost::range_value<LabelRange>::type LabelType;
 
 
  800template<
class InputType, 
class LabelType>
 
  803    std::size_t bootStrapSize = 0
 
  805    if(bootStrapSize == 0)
 
  810    for(std::size_t i = 0; i != bootStrapSize; ++i){
 
  812        bootstrapSet.element(index).weight += 1.0;
 
 
  828template<
class InputType>
 
  831    std::size_t bootStrapSize = 0
 
  833    if(bootStrapSize == 0)
 
  838    for(std::size_t i = 0; i != bootStrapSize; ++i){
 
  840        bootstrapSet.element(index).weight += 1.0;