29#ifndef REMORA_KERNELS_DEFAULT_PSTRF_HPP
30#define REMORA_KERNELS_DEFAULT_PSTRF_HPP
34#include "../../proxy_expressions.hpp"
35#include "../../dense.hpp"
37namespace remora{
namespace bindings {
// Blocked Cholesky factorization with diagonal (symmetric) pivoting, performed
// in place on A.
// NOTE(review): this is a fragmentary, line-numbered extract of the original
// header. Several original lines (including the function name, the tag
// parameter, some branches and the closing braces) are missing, so the comments
// below describe only the statements that are visible here. Presumably this is
// the lower-triangular overload that the transposed/upper variant forwards to
// via pstrf(..., lower()) -- TODO confirm.
//
// A: matrix factored in place, assumed square (m = size1() is used for both
//    dimensions). Rows and columns are swapped symmetrically while pivoting,
//    and each processed row is cleared to the right of its diagonal entry.
// P: pivot record; entry k+j is written with the absolute index (pivot+k) of
//    the row/column swapped into position k+j.
39template<
class MatA,
class VecP>
41 matrix_expression<MatA, cpu_tag> &A,
42 vector_expression<VecP, cpu_tag>& P,
// Columns are handled in panels of block_size. Each panel is factored column by
// column, then the trailing submatrix below/right of it is updated in one step.
68 std::size_t block_size = 20;
71 size_t m = A().size1();
// Cache of the remaining candidate pivot values (one per row). Entries are
// downdated as columns are factored.
73 vector<typename MatA::value_type> pivotValues(m);
// Stopping tolerance: m * m * machine epsilon * largest |diagonal entry|.
// NOTE(review): max_diag is seeded with A()(0,0) without std::abs, unlike the
// loop below, so a negative A(0,0) is handled inconsistently. Also, for m == 0,
// A()(0,0) is read out of range unless callers guarantee a non-empty matrix --
// confirm.
// NOTE(review): LAPACK xPSTRF's default tolerance scales with N, not N^2 --
// confirm the m * m factor is intended.
76 double max_diag = A()(0,0);
77 for(std::size_t i = 1; i < m; ++i)
78 max_diag = std::max(max_diag,std::abs(A()(i,i)));
79 double epsilon = m * m * std::numeric_limits<typename MatA::value_type>::epsilon() * max_diag;
// Outer loop over column panels starting at column k.
81 for(std::size_t k = 0; k < m; k += block_size){
// Width of this panel. Only the last panel can be narrower than block_size.
82 std::size_t currentSize = std::min(m-k,block_size);
// Views onto the trailing submatrix A[k:m, k:m] and its slice of the
// pivot-value cache (assuming subrange(A, r0, r1, c0, c1) semantics). Indices
// below are relative to k.
84 auto Ak = subrange(A,k,m,k,m);
85 auto pivots = subrange(pivotValues,k,m);
// Factor the panel one column j at a time.
88 for(std::size_t j = 0; j != currentSize; ++j){
// NOTE(review): original lines between this loop header and the next one are
// missing from this extract (line numbers jump from 92 to 95). The statement
// that follows is therefore not necessarily this loop's real body.
92 for(std::size_t i = 0; i != m-k; ++i)
// Downdate the candidate pivots of rows j onward by the square of their entry
// in the previously factored column j-1.
95 for(std::size_t i = j; i != m-k; ++i)
96 pivots(i) -= Ak(i,j-1) * Ak(i,j-1);
// Pick the largest remaining candidate among rows j onward. `pivot` is an index
// relative to Ak.
100 std::size_t pivot = std::max_element(pivots.begin()+j,pivots.end())-pivots.begin();
// Record the absolute pivot position and swap the row and column symmetrically.
// The cached pivot values are swapped too, so they follow the permutation.
102 P()(k+j) = (
int)(pivot+k);
103 A().swap_rows(k+j,k+pivot);
104 A().swap_columns(k+j,k+pivot);
105 std::swap(pivots(j),pivots(pivot));
109 auto pivotValue = pivots(j);
// If even the best remaining pivot is below the tolerance, the remaining
// trailing block is treated as numerically zero and cleared.
// NOTE(review): the statements after the clear (original lines 113-116) are
// missing here; presumably they are an early exit -- TODO confirm.
110 if(pivotValue < epsilon){
112 subrange(Ak,j,m-k,j,m-k).clear();
// Diagonal entry of the factor for column j.
117 Ak(j,j) = std::sqrt(pivotValue);
// Entries of column j strictly below the diagonal.
120 auto colLowerPart = subrange(column(Ak,j),j+1,m-k);
// Subtract the contribution of the already-factored columns 0..j-1 of this
// panel. blockLL holds rows j+1 onward of those columns; rowLeftPart holds row j.
// NOTE(review): assumes kernels::gemv(M, x, y, alpha) accumulates
// y += alpha * M * x -- verify against the kernel's signature.
124 auto blockLL = subrange(Ak,j+1,m-k,0,j);
125 auto curRow = row(Ak,j);
126 auto rowLeftPart = subrange(curRow,0,j);
136 kernels::gemv(blockLL,rowLeftPart,colLowerPart,-1);
// Scale by the new diagonal entry to finish column j of the factor.
138 colLowerPart /= Ak(j,j);
// Clear row j to the right of the diagonal, so only the lower factor remains.
140 subrange(Ak,j,j+1,j+1,Ak.size2()).clear();
// Trailing update after a full panel: subtract blockLL * blockLL^T from the
// submatrix below/right of the panel. This assumes gemm(A, B, C, alpha)
// computes C += alpha * A * B -- verify.
// Using block_size instead of currentSize here is correct only because this
// branch is reached exactly when currentSize == block_size. Otherwise
// currentSize == m-k and k+currentSize == m.
142 if(k+currentSize < m){
143 auto blockLL = subrange(Ak, block_size, m-k, 0, block_size);
144 auto blockLR = subrange(Ak, block_size, m-k, block_size, m-k);
145 kernels::gemm(blockLL,trans(blockLL), blockLR, -1);
152template<
class MatA,
class VecP>
154 matrix_expression<MatA, cpu_tag> &A,
155 vector_expression<VecP, cpu_tag>& P,
158 auto transA = trans(A);
159 return pstrf(transA,P,lower());