Skip to content

Commit

Permalink
For BLR frontal matrix compression,
Browse files Browse the repository at this point in the history
scale the absolute tolerance with the front matrix norm.
Also tweak the default absolute BLR tolerance.
  • Loading branch information
pghysels committed Jul 13, 2023
1 parent bc9bd20 commit d6d3218
Show file tree
Hide file tree
Showing 8 changed files with 47 additions and 7 deletions.
11 changes: 11 additions & 0 deletions src/BLR/BLRMatrixMPI.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,17 @@ namespace strumpack {
return Comm().all_reduce(this->rank(), MPI_MAX);
}

template<typename scalar_t>
typename RealType<scalar_t>::value_type
BLRMatrixMPI<scalar_t>::normF() const {
real_t nrm2 = 0.;
for (auto& b : blocks_) {
auto nrm = b->normF();
nrm2 += nrm*nrm;
}
return std::sqrt(nrm2);
}

template<typename scalar_t> void
BLRMatrixMPI<scalar_t>::print(const std::string& name) {
std::cout << "BLR(" << name << ")="
Expand Down
3 changes: 3 additions & 0 deletions src/BLR/BLRMatrixMPI.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,7 @@ namespace strumpack {
*/
template<typename scalar_t> class BLRMatrixMPI
: public structured::StructuredMatrix<scalar_t> {
using real_t = typename RealType<scalar_t>::value_type;
using DenseM_t = DenseMatrix<scalar_t>;
using DenseMW_t = DenseMatrixWrapper<scalar_t>;
using DistM_t = DistributedMatrix<scalar_t>;
Expand All @@ -146,6 +147,8 @@ namespace strumpack {
std::size_t total_nonzeros() const;
std::size_t max_rank() const;

real_t normF() const;

const MPIComm& Comm() const { return grid_->Comm(); }

const ProcessorGrid2D* grid() const { return grid_; }
Expand Down
4 changes: 2 additions & 2 deletions src/BLR/BLROptions.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -47,13 +47,13 @@ namespace strumpack {
return real_t(1e-4);
}
template<typename real_t> inline real_t default_BLR_abs_tol() {
return real_t(1e-10);
return real_t(1e-12);
}
template<> inline float default_BLR_rel_tol() {
return 1e-2;
}
template<> inline float default_BLR_abs_tol() {
return 1e-5;
return 1e-6;
}

enum class LowRankAlgorithm { RRQR, ACA, BACA };
Expand Down
3 changes: 3 additions & 0 deletions src/BLR/BLRTile.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ namespace strumpack {
template<typename scalar_t> class DenseTile;

template<typename scalar_t> class BLRTile {
using real_t = typename RealType<scalar_t>::value_type;
using DenseM_t = DenseMatrix<scalar_t>;
using Opts_t = BLROptions<scalar_t>;
using DMW_t = DenseMatrixWrapper<scalar_t>;
Expand All @@ -63,6 +64,8 @@ namespace strumpack {
virtual void dense(DenseM_t& A) const = 0;
virtual DenseM_t dense() const = 0;

virtual real_t normF() const = 0;

virtual std::unique_ptr<BLRTile<scalar_t>> clone() const = 0;

virtual std::unique_ptr<LRTile<scalar_t>>
Expand Down
3 changes: 3 additions & 0 deletions src/BLR/DenseTile.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ namespace strumpack {

template<typename scalar_t> class DenseTile
: public BLRTile<scalar_t> {
using real_t = typename RealType<scalar_t>::value_type;
using DenseM_t = DenseMatrix<scalar_t>;
using DMW_t = DenseMatrixWrapper<scalar_t>;
using BLRT_t = BLRTile<scalar_t>;
Expand All @@ -61,6 +62,8 @@ namespace strumpack {
void dense(DenseM_t& A) const override { A = D_; }
DenseM_t dense() const override { return D_; }

real_t normF() const { return D_.normF(); }

std::unique_ptr<BLRTile<scalar_t>> clone() const override;

std::unique_ptr<LRTile<scalar_t>>
Expand Down
8 changes: 8 additions & 0 deletions src/BLR/LRTile.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ namespace strumpack {
*/
template<typename scalar_t> class LRTile
: public BLRTile<scalar_t> {
using real_t = typename RealType<scalar_t>::value_type;
using DenseM_t = DenseMatrix<scalar_t>;
using DMW_t = DenseMatrixWrapper<scalar_t>;
using Opts_t = BLROptions<scalar_t>;
Expand Down Expand Up @@ -92,6 +93,13 @@ namespace strumpack {
void dense(DenseM_t& A) const override;
DenseM_t dense() const override;

real_t normF() const {
std::cerr << "WARNING: normF of compressed BLR matrix is not supported."
<< std::endl;
assert(false);
return 0.;
}

std::unique_ptr<BLRTile<scalar_t>> clone() const override;

std::unique_ptr<LRTile<scalar_t>>
Expand Down
12 changes: 9 additions & 3 deletions src/sparse/fronts/FrontalMatrixBLR.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -316,7 +316,7 @@ namespace strumpack {
(sep_begin_, sep_end_, this->upd(), e11, e12, e21);
BLRM_t::construct_and_partial_factor_col
(F11blr_, F12blr_, F21blr_, F22blr_, sep_tiles_,
upd_tiles_, admissibility_, opts.BLR_options(),
upd_tiles_, admissibility_, blr_opts,
[&](int i, bool part, std::size_t CP) {
build_front_cols
(A, i, part, CP, e11, e12, e21, task_depth, opts);
Expand All @@ -335,10 +335,16 @@ namespace strumpack {
lchild_->extend_add_to_dense(F11, F12, F21, F22_, this, task_depth);
if (rchild_)
rchild_->extend_add_to_dense(F11, F12, F21, F22_, this, task_depth);
auto nF11 = F11.normF();
auto nF12 = F12.normF();
auto nF21 = F21.normF();
auto nF = std::sqrt(nF11*nF11 + nF12*nF12 + nF21*nF21);
auto lopts = blr_opts;
lopts.set_abs_tol(lopts.abs_tol() * nF);
if (dsep)
BLRM_t::construct_and_partial_factor
(F11, F12, F21, F22_, F11blr_, F12blr_, F21blr_,
sep_tiles_, upd_tiles_, admissibility_, opts.BLR_options());
sep_tiles_, upd_tiles_, admissibility_, lopts);
}
} else { // ACA or BACA
auto F11elem = [&](const std::vector<std::size_t>& lI,
Expand Down Expand Up @@ -380,7 +386,7 @@ namespace strumpack {
BLRM_t::construct_and_partial_factor
(dsep, dupd, F11elem, F12elem, F21elem, F22elem,
F11blr_, F12blr_, F21blr_, F22blr_,
sep_tiles_, upd_tiles_, admissibility_, opts.BLR_options());
sep_tiles_, upd_tiles_, admissibility_, blr_opts);
}
if (lchild_) lchild_->release_work_memory();
if (rchild_) rchild_->release_work_memory();
Expand Down
10 changes: 8 additions & 2 deletions src/sparse/fronts/FrontalMatrixBLRMPI.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -307,10 +307,16 @@ namespace strumpack {
if (lchild_) lchild_->release_work_memory();
if (rchild_) rchild_->release_work_memory();
if (dim_sep() && grid2d().active()) {
auto nF11 = F11blr_.normF();
auto nF12 = F12blr_.normF();
auto nF21 = F21blr_.normF();
auto nF = std::sqrt(nF11*nF11 + nF12*nF12 + nF21*nF21);
auto lopts = opts.BLR_options();
lopts.set_abs_tol(lopts.abs_tol() * nF);
if (dim_upd())
piv_ = BLRMPI_t::partial_factor
(F11blr_, F12blr_, F21blr_, F22blr_, adm_, opts.BLR_options());
else piv_ = F11blr_.factor(adm_, opts.BLR_options());
(F11blr_, F12blr_, F21blr_, F22blr_, adm_, lopts);
else piv_ = F11blr_.factor(adm_, lopts);
// TODO flops?
}
}
Expand Down

0 comments on commit d6d3218

Please sign in to comment.