shark::CrossValidationError< ModelTypeT, LabelTypeT > Class Template Reference

Cross-validation error for selection of hyper-parameters.

#include <shark/ObjectiveFunctions/CrossValidationError.h>

Public Types

typedef ModelTypeT::InputType InputType
 
typedef ModelTypeT::OutputType OutputType
 
typedef LabelTypeT LabelType
 
typedef LabeledData< InputType, LabelType > DatasetType
 
typedef CVFolds< DatasetType > FoldsType
 
typedef ModelTypeT ModelType
 
typedef AbstractTrainer< ModelType, LabelType > TrainerType
 
typedef AbstractCost< LabelType, OutputType > CostType
 
Public Types inherited from shark::AbstractObjectiveFunction< RealVector, double >
enum  Feature
 List of features that are supported by an implementation.
 
typedef RealVector SearchPointType
 
typedef double ResultType
 
typedef boost::mpl::if_< std::is_arithmetic< double >, SearchPointType, RealMatrix >::type FirstOrderDerivative
 
typedef TypedFlags< Feature > Features
 This statement declares the member m_features. See Core/Flags.h for details.
 
typedef TypedFeatureNotAvailableException< Feature > FeatureNotAvailableException
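The FirstOrderDerivative typedef selects the derivative type at compile time: when the result type is arithmetic (as with double here), the gradient has the same type as the search point; otherwise a RealMatrix is used. The following minimal C++11 sketch reproduces that rule with std::conditional instead of boost::mpl::if_. RealVector and RealMatrix are placeholder structs standing in for Shark's types, and FirstOrderDerivativeFor is a hypothetical alias introduced only for illustration.

```cpp
#include <type_traits>
#include <vector>

// Placeholder stand-ins for Shark's RealVector and RealMatrix (assumption:
// only their identity as distinct types matters for this sketch).
struct RealVector { std::vector<double> data; };
struct RealMatrix { std::vector<std::vector<double>> data; };

// Hypothetical alias mirroring the boost::mpl::if_ selection:
// scalar result -> derivative shaped like the search point (a gradient);
// non-scalar result (e.g. several objectives) -> a matrix (a Jacobian).
template <class SearchPointType, class ResultType>
using FirstOrderDerivativeFor =
    typename std::conditional<std::is_arithmetic<ResultType>::value,
                              SearchPointType, RealMatrix>::type;

// For AbstractObjectiveFunction<RealVector, double> this yields RealVector.
static_assert(std::is_same<FirstOrderDerivativeFor<RealVector, double>, RealVector>::value,
              "scalar objectives use a vector-valued gradient");
// A vector-valued result selects RealMatrix instead.
static_assert(std::is_same<FirstOrderDerivativeFor<RealVector, RealVector>, RealMatrix>::value,
              "vector-valued objectives use a matrix-valued derivative");
```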
 

Public Member Functions

 CrossValidationError (FoldsType const &dataFolds, IParameterizable<> *meta, ModelType *model, TrainerType *trainer, CostType *cost)
 
std::string name () const
 From INameable: return the class name.
 
std::size_t numberOfVariables () const
 Accesses the number of variables.
 
double eval (RealVector const &parameters) const
 
Public Member Functions inherited from shark::AbstractObjectiveFunction< RealVector, double >
const Features & features () const
 
virtual void updateFeatures ()
 
bool hasValue () const
 returns whether this function can calculate its function value
 
bool hasFirstDerivative () const
 returns whether this function can calculate the first derivative
 
bool hasSecondDerivative () const
 returns whether this function can calculate the second derivative
 
bool canProposeStartingPoint () const
 returns whether this function can propose a starting point.
 
bool isConstrained () const
 returns whether this function has constraints on its search space
 
bool hasConstraintHandler () const
 returns whether this function provides a constraint handler (see getConstraintHandler())
 
bool canProvideClosestFeasible () const
 Returns whether this function can calculate the closest feasible point to an infeasible point.
 
bool isThreadSafe () const
 Returns true when the function can be used in parallel threads.
 
bool isNoisy () const
 Returns true when the function is noisy, i.e., repeated evaluations at the same point may return different values.
 
 AbstractObjectiveFunction ()
 Default ctor.
 
virtual ~AbstractObjectiveFunction ()
 Virtual destructor.
 
virtual void init ()
 
void setRng (random::rng_type *rng)
 Sets the Rng used by the objective function.
 
virtual bool hasScalableDimensionality () const
 
virtual void setNumberOfVariables (std::size_t numberOfVariables)
 Adjusts the number of variables if the function is scalable.
 
virtual std::size_t numberOfObjectives () const
 
virtual bool hasScalableObjectives () const
 
virtual void setNumberOfObjectives (std::size_t numberOfObjectives)
 Adjusts the number of objectives if the function is scalable.
 
std::size_t evaluationCounter () const
 Accesses the evaluation counter of the function.
 
AbstractConstraintHandler< SearchPointType > const & getConstraintHandler () const
 Returns the constraint handler of the function if it has one.
 
virtual bool isFeasible (const SearchPointType &input) const
 Tests whether a point in the search space is feasible, i.e., whether all constraints are fulfilled.
 
virtual void closestFeasible (SearchPointType &input) const
 If supported, the supplied point is repaired such that it satisfies all of the function's constraints.
 
virtual SearchPointType proposeStartingPoint () const
 Proposes a starting point in the feasible search space of the function.
 
ResultType operator() (SearchPointType const &input) const
 Evaluates the function. Useful together with STL algorithms such as std::transform.
 
virtual ResultType evalDerivative (SearchPointType const &input, FirstOrderDerivative &derivative) const
 Evaluates the objective function and calculates its gradient.
 
virtual ResultType evalDerivative (SearchPointType const &input, SecondOrderDerivative &derivative) const
 Evaluates the objective function and calculates its gradient and Hessian.
 
Public Member Functions inherited from shark::INameable
virtual ~INameable ()
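operator() is described above as useful together with STL algorithms such as std::transform. The sketch below illustrates that pattern on a hypothetical toy objective (SphereObjective, not part of Shark) whose operator() forwards to eval(), assuming the inherited operator() behaves the same way; std::vector<double> stands in for RealVector.

```cpp
#include <algorithm>
#include <vector>

// Hypothetical toy objective f(x) = sum_i x_i^2. operator() forwards to
// eval() so that the object can be passed directly to STL algorithms.
struct SphereObjective {
    double eval(std::vector<double> const& x) const {
        double sum = 0.0;
        for (double xi : x) sum += xi * xi;
        return sum;
    }
    double operator()(std::vector<double> const& x) const { return eval(x); }
};

// Evaluates the objective at every candidate point and returns the values
// in the same order as the input points.
std::vector<double> evaluateAll(SphereObjective const& f,
                                std::vector<std::vector<double>> const& points) {
    std::vector<double> values(points.size());
    std::transform(points.begin(), points.end(), values.begin(), f);
    return values;
}
```

Applied to a CrossValidationError object, the same call would evaluate every candidate hyper-parameter vector (training one model per fold for each candidate), which is the core of a simple grid search.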
 

Additional Inherited Members

Protected Member Functions inherited from shark::AbstractObjectiveFunction< RealVector, double >
void announceConstraintHandler (AbstractConstraintHandler< SearchPointType > const *handler)
 helper function which is called to announce the presence of a constraint handler.
 
Protected Attributes inherited from shark::AbstractObjectiveFunction< RealVector, double >
Features m_features
 
std::size_t m_evaluationCounter
 Evaluation counter, default value: 0.
 
AbstractConstraintHandler< SearchPointType > const * m_constraintHandler
 
random::rng_type * mep_rng
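The hasValue(), hasFirstDerivative() and similar queries report capabilities recorded in the protected m_features flag set. The sketch below shows the general bit-flag technique with hypothetical names: the Feature values, ObjectiveSketch and CrossValidationSketch are illustrative assumptions, and the real Feature enum (see Core/Flags.h) may use different names and values. The sketch also assumes, without confirmation from this page, that a cross-validation objective announces a function value but no derivatives.

```cpp
#include <cstdint>

// Hypothetical feature bits; the real AbstractObjectiveFunction::Feature
// enum may differ in names and values.
enum Feature : std::uint32_t {
    HAS_VALUE             = 1u << 0,
    HAS_FIRST_DERIVATIVE  = 1u << 1,
    HAS_SECOND_DERIVATIVE = 1u << 2,
    IS_THREAD_SAFE        = 1u << 3,
    IS_NOISY              = 1u << 4
};

// Minimal stand-in for an objective function that stores its capabilities
// in a bit set, in the spirit of the m_features member.
class ObjectiveSketch {
public:
    bool hasValue() const { return (m_features & HAS_VALUE) != 0; }
    bool hasFirstDerivative() const { return (m_features & HAS_FIRST_DERIVATIVE) != 0; }
    bool isThreadSafe() const { return (m_features & IS_THREAD_SAFE) != 0; }
    bool isNoisy() const { return (m_features & IS_NOISY) != 0; }

protected:
    std::uint32_t m_features = 0;  // no capabilities announced by default
};

// A derived class announces what it supports in its constructor. In this
// sketch (an assumption), the cross-validation objective only has a value.
class CrossValidationSketch : public ObjectiveSketch {
public:
    CrossValidationSketch() { m_features |= HAS_VALUE; }
};
```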
 

Detailed Description

template<class ModelTypeT, class LabelTypeT = typename ModelTypeT::OutputType>
class shark::CrossValidationError< ModelTypeT, LabelTypeT >

Cross-validation error for selection of hyper-parameters.

The cross-validation error estimates how well a model trained on a problem generalizes to unseen data. It is commonly used for model selection.
In Shark, the cross-validation procedure is abstracted as follows: First, the given point is written into an IParameterizable object (such as a regularizer or a trainer). Then, for each fold, the trainer, now configured with these settings, trains the model on the fold's training set, and the trained model is evaluated on the corresponding validation set with the cost function. The average cost over all folds is returned.
Thus, the cross-validation procedure requires a "meta" IParameterizable object, a model, a trainer, a data set, and a cost function.
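The procedure can be sketched without Shark as follows. Everything here is a hypothetical stand-in chosen for brevity: the "meta" object is a one-dimensional ridge-regression trainer whose only hyper-parameter is the regularization strength lambda, the model is the scalar weight w in y = w * x, the cost is the mean squared error, and sample i is assigned to validation fold i % k. It mirrors the three steps (write the point into the meta object, train and evaluate per fold, average) rather than Shark's actual implementation.

```cpp
#include <cstddef>
#include <stdexcept>
#include <vector>

// Hypothetical stand-ins for the roles named in the text; not Shark API.
struct Sample { double x; double y; };

// Plays both the "meta" and the trainer role: stores lambda.
struct RidgeTrainer1D {
    double lambda = 0.0;
    void setParameterVector(std::vector<double> const& p) { lambda = p.at(0); }

    // Ridge regression without intercept: w = sum(x*y) / (sum(x*x) + lambda).
    double train(std::vector<Sample> const& data) const {
        double sxy = 0.0, sxx = 0.0;
        for (Sample const& s : data) { sxy += s.x * s.y; sxx += s.x * s.x; }
        return sxy / (sxx + lambda);
    }
};

// Cost function: mean squared error of the model w on a validation set.
double meanSquaredError(double w, std::vector<Sample> const& data) {
    double sum = 0.0;
    for (Sample const& s : data) { double r = w * s.x - s.y; sum += r * r; }
    return sum / static_cast<double>(data.size());
}

// Cross-validation error of the hyper-parameter vector `params`.
double crossValidationError(std::vector<Sample> const& data, std::size_t k,
                            std::vector<double> const& params) {
    if (k < 2 || data.size() < k) throw std::invalid_argument("need 2 <= k <= n");
    RidgeTrainer1D trainer;
    trainer.setParameterVector(params);            // 1) write the point into the meta object
    double total = 0.0;
    for (std::size_t fold = 0; fold < k; ++fold) {
        std::vector<Sample> training, validation;
        for (std::size_t i = 0; i < data.size(); ++i)
            (i % k == fold ? validation : training).push_back(data[i]);
        double w = trainer.train(training);        // 2) train on the other folds
        total += meanSquaredError(w, validation);  //    and evaluate on this fold
    }
    return total / static_cast<double>(k);         // 3) average over all folds
}
```

Evaluating this function for several values of lambda and keeping the one with the smallest result is the model-selection use described above.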

Definition at line 72 of file CrossValidationError.h.

Member Typedef Documentation

◆ CostType

template<class ModelTypeT , class LabelTypeT = typename ModelTypeT::OutputType>
typedef AbstractCost<LabelType, OutputType> shark::CrossValidationError< ModelTypeT, LabelTypeT >::CostType

Definition at line 82 of file CrossValidationError.h.

◆ DatasetType

template<class ModelTypeT , class LabelTypeT = typename ModelTypeT::OutputType>
typedef LabeledData<InputType, LabelType> shark::CrossValidationError< ModelTypeT, LabelTypeT >::DatasetType

Definition at line 78 of file CrossValidationError.h.

◆ FoldsType

template<class ModelTypeT , class LabelTypeT = typename ModelTypeT::OutputType>
typedef CVFolds<DatasetType> shark::CrossValidationError< ModelTypeT, LabelTypeT >::FoldsType

Definition at line 79 of file CrossValidationError.h.
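FoldsType holds the partition of the data set into training and validation parts. As a rough illustration of what such a partition contains, the hypothetical helpers below (not Shark's CVFolds interface) split n sample indices into k contiguous validation folds whose sizes differ by at most one, and build each fold's training set from the remaining indices. A real fold generator may additionally shuffle the samples first.

```cpp
#include <cstddef>
#include <vector>

// Hypothetical helper: splits indices 0..n-1 into k validation folds whose
// sizes differ by at most one (the first n % k folds get one extra index).
std::vector<std::vector<std::size_t>> makeValidationFolds(std::size_t n, std::size_t k) {
    std::vector<std::vector<std::size_t>> folds(k);
    std::size_t start = 0;
    for (std::size_t f = 0; f < k; ++f) {
        std::size_t size = n / k + (f < n % k ? 1 : 0);
        for (std::size_t i = start; i < start + size; ++i) folds[f].push_back(i);
        start += size;
    }
    return folds;
}

// Training indices for fold f: every index that is not in folds[f].
std::vector<std::size_t> trainingIndices(std::vector<std::vector<std::size_t>> const& folds,
                                         std::size_t f) {
    std::vector<std::size_t> result;
    for (std::size_t g = 0; g < folds.size(); ++g)
        if (g != f) result.insert(result.end(), folds[g].begin(), folds[g].end());
    return result;
}
```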

◆ InputType

template<class ModelTypeT , class LabelTypeT = typename ModelTypeT::OutputType>
typedef ModelTypeT::InputType shark::CrossValidationError< ModelTypeT, LabelTypeT >::InputType

Definition at line 75 of file CrossValidationError.h.

◆ LabelType

template<class ModelTypeT , class LabelTypeT = typename ModelTypeT::OutputType>
typedef LabelTypeT shark::CrossValidationError< ModelTypeT, LabelTypeT >::LabelType

Definition at line 77 of file CrossValidationError.h.

◆ ModelType

template<class ModelTypeT , class LabelTypeT = typename ModelTypeT::OutputType>
typedef ModelTypeT shark::CrossValidationError< ModelTypeT, LabelTypeT >::ModelType

Definition at line 80 of file CrossValidationError.h.

◆ OutputType

template<class ModelTypeT , class LabelTypeT = typename ModelTypeT::OutputType>
typedef ModelTypeT::OutputType shark::CrossValidationError< ModelTypeT, LabelTypeT >::OutputType

Definition at line 76 of file CrossValidationError.h.

◆ TrainerType

template<class ModelTypeT , class LabelTypeT = typename ModelTypeT::OutputType>
typedef AbstractTrainer<ModelType, LabelType> shark::CrossValidationError< ModelTypeT, LabelTypeT >::TrainerType

Definition at line 81 of file CrossValidationError.h.
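All of the typedefs above are derived from the two template arguments, and LabelTypeT defaults to the model's OutputType. The sketch below mirrors that derivation in a hypothetical CrossValidationTypes struct with a toy model type (ToyScorer) whose outputs are real-valued score vectors; both names are assumptions for illustration only. Supplying the second argument explicitly covers the case where labels differ from model outputs, such as class indices paired with score vectors.

```cpp
#include <type_traits>
#include <vector>

// Hypothetical toy model exposing the two nested types the class relies on.
struct ToyScorer {
    typedef std::vector<double> InputType;
    typedef std::vector<double> OutputType;  // e.g. one score per class
};

// Mirrors how InputType, OutputType and LabelType are derived; LabelTypeT
// defaults to the model's OutputType, as in the class declaration.
template <class ModelTypeT, class LabelTypeT = typename ModelTypeT::OutputType>
struct CrossValidationTypes {
    typedef typename ModelTypeT::InputType InputType;
    typedef typename ModelTypeT::OutputType OutputType;
    typedef LabelTypeT LabelType;
};

// Default: labels have the same type as the model outputs.
static_assert(std::is_same<CrossValidationTypes<ToyScorer>::LabelType,
                           std::vector<double>>::value, "default label type");
// Explicit: class-index labels paired with score-vector outputs.
static_assert(std::is_same<CrossValidationTypes<ToyScorer, unsigned int>::LabelType,
                           unsigned int>::value, "explicit label type");
```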

Constructor & Destructor Documentation

◆ CrossValidationError()

template<class ModelTypeT , class LabelTypeT = typename ModelTypeT::OutputType>
shark::CrossValidationError< ModelTypeT, LabelTypeT >::CrossValidationError ( FoldsType const &  dataFolds,
IParameterizable<> *  meta,
ModelType *  model,
TrainerType *  trainer,
CostType *  cost 
)
inline

Definition at line 95 of file CrossValidationError.h.

Member Function Documentation

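◆ eval()

template<class ModelTypeT , class LabelTypeT = typename ModelTypeT::OutputType>
double shark::CrossValidationError< ModelTypeT, LabelTypeT >::eval ( RealVector const &  parameters ) const

Evaluates the cross-validation error at the given point: the parameters are written into the meta object, the model is trained and evaluated on every fold, and the average cost over all folds is returned.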

◆ name()

template<class ModelTypeT , class LabelTypeT = typename ModelTypeT::OutputType>
std::string shark::CrossValidationError< ModelTypeT, LabelTypeT >::name ( ) const
inline virtual

From INameable: return the class name.

Reimplemented from shark::INameable.

Definition at line 109 of file CrossValidationError.h.

References shark::INameable::name().

◆ numberOfVariables()

template<class ModelTypeT , class LabelTypeT = typename ModelTypeT::OutputType>
std::size_t shark::CrossValidationError< ModelTypeT, LabelTypeT >::numberOfVariables ( ) const
inline virtual

Accesses the number of variables.

The documentation for this class was generated from the following file: