Skip to content

Commit

Permalink
In sparse direct solver MAGMA code, always use MAGMA. Add some more e…
Browse files Browse the repository at this point in the history
…rror checking.
  • Loading branch information
pghysels committed Jul 7, 2023
1 parent 0a143ba commit 9eca60e
Show file tree
Hide file tree
Showing 5 changed files with 57 additions and 17 deletions.
22 changes: 11 additions & 11 deletions src/dense/CUDAWrapper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -227,24 +227,24 @@ namespace strumpack {
std::complex<double>,
DenseMatrix<std::complex<double>>&);

void getrf_buffersize
(SOLVERHandle& handle, int m, int n, float* A, int lda, int* Lwork) {
// Single precision (float) overload. Forwards every argument unchanged
// to cusolverDnSgetrf_bufferSize and hands the returned status to
// gpu_check. Per the cuSOLVER documentation, Lwork receives the
// required workspace size for the matching getrf call.
void getrf_buffersize(SOLVERHandle& handle, int m, int n,
float* A, int lda, int* Lwork) {
gpu_check(cusolverDnSgetrf_bufferSize(handle, m, n, A, lda, Lwork));
}
void getrf_buffersize
(SOLVERHandle& handle, int m, int n, double *A, int lda,
int* Lwork) {
// Double precision overload. Forwards every argument unchanged to
// cusolverDnDgetrf_bufferSize and hands the returned status to
// gpu_check. Per the cuSOLVER documentation, Lwork receives the
// required workspace size for the matching getrf call.
void getrf_buffersize(SOLVERHandle& handle, int m, int n,
double *A, int lda,
int* Lwork) {
gpu_check(cusolverDnDgetrf_bufferSize(handle, m, n, A, lda, Lwork));
}
void getrf_buffersize
(SOLVERHandle& handle, int m, int n, std::complex<float>* A, int lda,
int *Lwork) {
// Single precision complex overload. A is reinterpreted as cuComplex*
// (the type the cuSOLVER entry point takes); all other arguments are
// forwarded unchanged to cusolverDnCgetrf_bufferSize, and the returned
// status is handed to gpu_check.
void getrf_buffersize(SOLVERHandle& handle, int m, int n,
std::complex<float>* A, int lda,
int *Lwork) {
gpu_check(cusolverDnCgetrf_bufferSize
(handle, m, n, reinterpret_cast<cuComplex*>(A), lda, Lwork));
}
void getrf_buffersize
(SOLVERHandle& handle, int m, int n, std::complex<double>* A, int lda,
int *Lwork) {
void getrf_buffersize(SOLVERHandle& handle, int m, int n,
std::complex<double>* A, int lda,
int *Lwork) {
gpu_check(cusolverDnZgetrf_bufferSize
(handle, m, n,
reinterpret_cast<cuDoubleComplex*>(A), lda, Lwork));
Expand Down
4 changes: 4 additions & 0 deletions src/dense/CUDAWrapper.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,10 @@ namespace strumpack {

void init();

/**
 * Query the CUDA runtime with cudaPeekAtLastError() and pass the
 * resulting status to gpu_check. Per the CUDA runtime documentation,
 * peeking reports the most recent error without resetting it.
 * NOTE(review): gpu_check is defined elsewhere; its reaction to a
 * non-success status is not visible here.
 */
inline void gpu_peek_at_last_error() {
gpu_check(cudaPeekAtLastError());
}

/**
 * Call cudaDeviceSynchronize() and pass the returned status to
 * gpu_check.
 */
inline void synchronize() {
gpu_check(cudaDeviceSynchronize());
}
Expand Down
6 changes: 5 additions & 1 deletion src/dense/HIPWrapper.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@

namespace strumpack {
namespace gpu {

// Block-count limits. The names suggest the y and z dimensions of a
// launch grid -- confirm against the kernels that use them.
const unsigned int MAX_BLOCKS_Y = 65535;
const unsigned int MAX_BLOCKS_Z = 65535;

Expand All @@ -59,6 +59,10 @@ namespace strumpack {

void init();

/**
 * Query the HIP runtime with hipPeekAtLastError() and pass the
 * resulting status to gpu_check.
 * NOTE(review): gpu_check is defined elsewhere; its reaction to a
 * non-success status is not visible here.
 */
inline void gpu_peek_at_last_error() {
gpu_check(hipPeekAtLastError());
}

/**
 * Call hipDeviceSynchronize() and pass the returned status to
 * gpu_check.
 */
inline void synchronize() {
gpu_check(hipDeviceSynchronize());
}
Expand Down
40 changes: 36 additions & 4 deletions src/dense/MAGMAWrapper.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -125,10 +125,15 @@ namespace strumpack {
void* work, magma_int_t* lwork,
magma_int_t batchCount, magma_queue_t queue) {
if (!batchCount) return 0;
return magma_sgetrf_vbatched_max_nocheck_work
auto info = magma_sgetrf_vbatched_max_nocheck_work
(m, n, max_m, max_n, max_minmn, max_mxn,
dA_array, ldda, dipiv_array, info_array,
work, lwork, batchCount, queue);
if (info)
std::cerr << "ERROR: magma_sgetrf_vbatched_max_nocheck_work "
<< "failed with info= " << info << std::endl;
gpu_peek_at_last_error();
return info;
}
inline magma_int_t getrf_vbatched_max_nocheck_work
(magma_int_t* m, magma_int_t* n,
Expand All @@ -139,10 +144,15 @@ namespace strumpack {
void* work, magma_int_t* lwork,
magma_int_t batchCount, magma_queue_t queue) {
if (!batchCount) return 0;
return magma_dgetrf_vbatched_max_nocheck_work
auto info = magma_dgetrf_vbatched_max_nocheck_work
(m, n, max_m, max_n, max_minmn, max_mxn,
dA_array, ldda, dipiv_array, info_array,
work, lwork, batchCount, queue);
if (info)
std::cerr << "ERROR: magma_dgetrf_vbatched_max_nocheck_work "
<< "failed with info= " << info << std::endl;
gpu_peek_at_last_error();
return info;
}
inline magma_int_t getrf_vbatched_max_nocheck_work
(magma_int_t* m, magma_int_t* n,
Expand All @@ -153,10 +163,15 @@ namespace strumpack {
void* work, magma_int_t* lwork,
magma_int_t batchCount, magma_queue_t queue) {
if (!batchCount) return 0;
return magma_cgetrf_vbatched_max_nocheck_work
auto info = magma_cgetrf_vbatched_max_nocheck_work
(m, n, max_m, max_n, max_minmn, max_mxn,
(magmaFloatComplex**)dA_array, ldda,
dipiv_array, info_array, work, lwork, batchCount, queue);
if (info)
std::cerr << "ERROR: magma_cgetrf_vbatched_max_nocheck_work "
<< "failed with info= " << info << std::endl;
gpu_peek_at_last_error();
return info;
}
inline magma_int_t getrf_vbatched_max_nocheck_work
(magma_int_t* m, magma_int_t* n,
Expand All @@ -167,10 +182,15 @@ namespace strumpack {
void* work, magma_int_t* lwork,
magma_int_t batchCount, magma_queue_t queue) {
if (!batchCount) return 0;
return magma_zgetrf_vbatched_max_nocheck_work
auto info = magma_zgetrf_vbatched_max_nocheck_work
(m, n, max_m, max_n, max_minmn, max_mxn,
(magmaDoubleComplex**)dA_array, ldda,
dipiv_array, info_array, work, lwork, batchCount, queue);
if (info)
std::cerr << "ERROR: magma_zgetrf_vbatched_max_nocheck_work "
<< "failed with info= " << info << std::endl;
gpu_peek_at_last_error();
return info;
}

inline void trsm_vbatched_max_nocheck
Expand All @@ -184,6 +204,7 @@ namespace strumpack {
magmablas_strsm_vbatched_max_nocheck
(side, uplo, transA, diag, max_m, max_n, m, n, alpha, dA_array,
ldda, dB_array, lddb, batchCount, queue);
gpu_peek_at_last_error();
}
inline void trsm_vbatched_max_nocheck
(magma_side_t side, magma_uplo_t uplo, magma_trans_t transA,
Expand All @@ -196,6 +217,7 @@ namespace strumpack {
magmablas_dtrsm_vbatched_max_nocheck
(side, uplo, transA, diag, max_m, max_n, m, n, alpha, dA_array,
ldda, dB_array, lddb, batchCount, queue);
gpu_peek_at_last_error();
}
inline void trsm_vbatched_max_nocheck
(magma_side_t side, magma_uplo_t uplo, magma_trans_t transA,
Expand All @@ -211,6 +233,7 @@ namespace strumpack {
(side, uplo, transA, diag, max_m, max_n, m, n, alpha_,
(magmaFloatComplex**)dA_array, ldda,
(magmaFloatComplex**)dB_array, lddb, batchCount, queue);
gpu_peek_at_last_error();
}
inline void trsm_vbatched_max_nocheck
(magma_side_t side, magma_uplo_t uplo, magma_trans_t transA,
Expand All @@ -226,6 +249,7 @@ namespace strumpack {
(side, uplo, transA, diag, max_m, max_n, m, n, alpha_,
(magmaDoubleComplex**)dA_array, ldda,
(magmaDoubleComplex**)dB_array, lddb, batchCount, queue);
gpu_peek_at_last_error();
}

inline void gemm_vbatched_max_nocheck
Expand All @@ -242,6 +266,7 @@ namespace strumpack {
(transA, transB, m, n, k, alpha, dA_array, ldda,
dB_array, lddb, beta, dC_array, lddc, batchCount,
max_m, max_n, max_k, queue);
gpu_peek_at_last_error();
}
inline void gemm_vbatched_max_nocheck
(magma_trans_t transA, magma_trans_t transB,
Expand All @@ -257,6 +282,7 @@ namespace strumpack {
(transA, transB, m, n, k, alpha, dA_array, ldda,
dB_array, lddb, beta, dC_array, lddc, batchCount,
max_m, max_n, max_k, queue);
gpu_peek_at_last_error();
}
inline void gemm_vbatched_max_nocheck
(magma_trans_t transA, magma_trans_t transB,
Expand All @@ -278,6 +304,7 @@ namespace strumpack {
(magmaFloatComplex**)dB_array, lddb, beta_,
(magmaFloatComplex**)dC_array, lddc, batchCount,
max_m, max_n, max_k, queue);
gpu_peek_at_last_error();
}
inline void gemm_vbatched_max_nocheck
(magma_trans_t transA, magma_trans_t transB,
Expand All @@ -299,6 +326,7 @@ namespace strumpack {
(magmaDoubleComplex**)dB_array, lddb, beta_,
(magmaDoubleComplex**)dC_array, lddc, batchCount,
max_m, max_n, max_k, queue);
gpu_peek_at_last_error();
}

inline void gemv_vbatched_max_nocheck
Expand All @@ -314,6 +342,7 @@ namespace strumpack {
const_cast<float**>(dA_array), ldda,
const_cast<float**>(dB_array), lddb, beta,
dC_array, lddc, batchCount, max_m, max_n, queue);
gpu_peek_at_last_error();
}
inline void gemv_vbatched_max_nocheck
(magma_trans_t trans, magma_int_t *m, magma_int_t *n, double alpha,
Expand All @@ -328,6 +357,7 @@ namespace strumpack {
const_cast<double**>(dA_array), ldda,
const_cast<double**>(dB_array), lddb, beta,
dC_array, lddc, batchCount, max_m, max_n, queue);
gpu_peek_at_last_error();
}
inline void gemv_vbatched_max_nocheck
(magma_trans_t trans, magma_int_t *m, magma_int_t *n, std::complex<float> alpha,
Expand All @@ -346,6 +376,7 @@ namespace strumpack {
(magmaFloatComplex**)(const_cast<std::complex<float>**>(dB_array)), lddb, beta_,
(magmaFloatComplex**)dC_array, lddc, batchCount,
max_m, max_n, queue);
gpu_peek_at_last_error();
}
inline void gemv_vbatched_max_nocheck
(magma_trans_t trans, magma_int_t *m, magma_int_t *n,
Expand All @@ -365,6 +396,7 @@ namespace strumpack {
(magmaDoubleComplex**)(const_cast<std::complex<double>**>(dB_array)), lddb, beta_,
(magmaDoubleComplex**)dC_array, lddc, batchCount,
max_m, max_n, queue);
gpu_peek_at_last_error();
}

} // end namespace magma
Expand Down
2 changes: 1 addition & 1 deletion src/sparse/fronts/FrontalMatrixMAGMA.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -231,7 +231,7 @@ namespace strumpack {
}

// static const int FRONT_SMALL = 10000000;
static const int MIN_BATCH_COUNT = 8;
static const int MIN_BATCH_COUNT = 0;
std::vector<FM_t*> f;
std::size_t factor_size = 0, Schur_size = 0, piv_size = 0,
total_upd_size = 0, work_bytes, ea_bytes,
Expand Down

0 comments on commit 9eca60e

Please sign in to comment.