167 SHARK_RUNTIME_CHECK(
m_nhp == parameters.size(),
"[SvmLogisticInterpretation::evalDerivative] wrong number of parameters");
173 std::vector< unsigned int > tmp_helper_labels(
m_numSamples);
174 std::vector< RealVector > tmp_helper_preds(
m_numSamples);
176 unsigned int next_label = 0;
199 csvm_trainer.
train(svm, cur_train_data);
203 for (std::size_t j=0; j<cur_vsize; j++) {
205 tmp_helper_labels[next_label] = cur_vlabels.
element(j);
206 tmp_helper_preds[next_label] = cur_vscores.
element(j);
209 noalias(row(all_validation_predict_derivs, next_label)) = der;
217 LinearModel<> logistic_model = fitLogistic(validation_dataset);
221 derivative.resize(
m_nhp);
224 std::size_t start = 0;
225 for(
auto const& batch: validation_dataset.
batches()){
226 std::size_t end = start+batch.size();
228 RealMatrix lossGradient;
229 error += logistic_loss.evalDerivative(batch.label,logistic_model(batch.input),lossGradient);
230 noalias(derivative) += column(lossGradient,0) % rows(all_validation_predict_derivs,start,end);