Skip to content

Commit

Permalink
Reset CUDA errors after MAGMA calls.
Browse files Browse the repository at this point in the history
MAGMA can generate some errors for large problems,
while still giving correct results,
see https://bitbucket.org/icl/magma/issues/69/cudaerrorlaunchoutofresources-in
  • Loading branch information
pghysels committed Jul 10, 2023
1 parent 9eca60e commit df2da3f
Show file tree
Hide file tree
Showing 3 changed files with 43 additions and 35 deletions.
6 changes: 5 additions & 1 deletion src/dense/CUDAWrapper.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -63,10 +63,14 @@ namespace strumpack {

void init();

inline void gpu_peek_at_last_error() {
inline void peek_at_last_error() {
gpu_check(cudaPeekAtLastError());
}

inline void get_last_error() {
cudaGetLastError();
}

inline void synchronize() {
gpu_check(cudaDeviceSynchronize());
}
Expand Down
8 changes: 6 additions & 2 deletions src/dense/HIPWrapper.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@

namespace strumpack {
namespace gpu {
7

const unsigned int MAX_BLOCKS_Y = 65535;
const unsigned int MAX_BLOCKS_Z = 65535;

Expand All @@ -59,10 +59,14 @@ namespace strumpack {

void init();

inline void gpu_peek_at_last_error() {
inline void peek_at_last_error() {
gpu_check(hipPeekAtLastError());
}

inline void get_last_error() {
hipGetLastError();
}

inline void synchronize() {
gpu_check(hipDeviceSynchronize());
}
Expand Down
64 changes: 32 additions & 32 deletions src/dense/MAGMAWrapper.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -129,11 +129,11 @@ namespace strumpack {
(m, n, max_m, max_n, max_minmn, max_mxn,
dA_array, ldda, dipiv_array, info_array,
work, lwork, batchCount, queue);
if (info)
std::cerr << "ERROR: magma_sgetrf_vbatched_max_nocheck_work "
<< "failed with info= " << info << std::endl;
gpu_peek_at_last_error();
return info;
if (info)
std::cerr << "ERROR: magma_sgetrf_vbatched_max_nocheck_work "
<< "failed with info= " << info << std::endl;
get_last_error();
return info;
}
inline magma_int_t getrf_vbatched_max_nocheck_work
(magma_int_t* m, magma_int_t* n,
Expand All @@ -148,11 +148,11 @@ namespace strumpack {
(m, n, max_m, max_n, max_minmn, max_mxn,
dA_array, ldda, dipiv_array, info_array,
work, lwork, batchCount, queue);
if (info)
std::cerr << "ERROR: magma_dgetrf_vbatched_max_nocheck_work "
<< "failed with info= " << info << std::endl;
gpu_peek_at_last_error();
return info;
if (info)
std::cerr << "ERROR: magma_dgetrf_vbatched_max_nocheck_work "
<< "failed with info= " << info << std::endl;
get_last_error();
return info;
}
inline magma_int_t getrf_vbatched_max_nocheck_work
(magma_int_t* m, magma_int_t* n,
Expand All @@ -167,11 +167,11 @@ namespace strumpack {
(m, n, max_m, max_n, max_minmn, max_mxn,
(magmaFloatComplex**)dA_array, ldda,
dipiv_array, info_array, work, lwork, batchCount, queue);
if (info)
std::cerr << "ERROR: magma_cgetrf_vbatched_max_nocheck_work "
<< "failed with info= " << info << std::endl;
gpu_peek_at_last_error();
return info;
if (info)
std::cerr << "ERROR: magma_cgetrf_vbatched_max_nocheck_work "
<< "failed with info= " << info << std::endl;
get_last_error();
return info;
}
inline magma_int_t getrf_vbatched_max_nocheck_work
(magma_int_t* m, magma_int_t* n,
Expand All @@ -186,11 +186,11 @@ namespace strumpack {
(m, n, max_m, max_n, max_minmn, max_mxn,
(magmaDoubleComplex**)dA_array, ldda,
dipiv_array, info_array, work, lwork, batchCount, queue);
if (info)
std::cerr << "ERROR: magma_zgetrf_vbatched_max_nocheck_work "
<< "failed with info= " << info << std::endl;
gpu_peek_at_last_error();
return info;
if (info)
std::cerr << "ERROR: magma_zgetrf_vbatched_max_nocheck_work "
<< "failed with info= " << info << std::endl;
get_last_error();
return info;
}

inline void trsm_vbatched_max_nocheck
Expand All @@ -204,7 +204,7 @@ namespace strumpack {
magmablas_strsm_vbatched_max_nocheck
(side, uplo, transA, diag, max_m, max_n, m, n, alpha, dA_array,
ldda, dB_array, lddb, batchCount, queue);
gpu_peek_at_last_error();
get_last_error();
}
inline void trsm_vbatched_max_nocheck
(magma_side_t side, magma_uplo_t uplo, magma_trans_t transA,
Expand All @@ -217,7 +217,7 @@ namespace strumpack {
magmablas_dtrsm_vbatched_max_nocheck
(side, uplo, transA, diag, max_m, max_n, m, n, alpha, dA_array,
ldda, dB_array, lddb, batchCount, queue);
gpu_peek_at_last_error();
get_last_error();
}
inline void trsm_vbatched_max_nocheck
(magma_side_t side, magma_uplo_t uplo, magma_trans_t transA,
Expand All @@ -233,7 +233,7 @@ namespace strumpack {
(side, uplo, transA, diag, max_m, max_n, m, n, alpha_,
(magmaFloatComplex**)dA_array, ldda,
(magmaFloatComplex**)dB_array, lddb, batchCount, queue);
gpu_peek_at_last_error();
get_last_error();
}
inline void trsm_vbatched_max_nocheck
(magma_side_t side, magma_uplo_t uplo, magma_trans_t transA,
Expand All @@ -249,7 +249,7 @@ namespace strumpack {
(side, uplo, transA, diag, max_m, max_n, m, n, alpha_,
(magmaDoubleComplex**)dA_array, ldda,
(magmaDoubleComplex**)dB_array, lddb, batchCount, queue);
gpu_peek_at_last_error();
get_last_error();
}

inline void gemm_vbatched_max_nocheck
Expand All @@ -266,7 +266,7 @@ namespace strumpack {
(transA, transB, m, n, k, alpha, dA_array, ldda,
dB_array, lddb, beta, dC_array, lddc, batchCount,
max_m, max_n, max_k, queue);
gpu_peek_at_last_error();
get_last_error();
}
inline void gemm_vbatched_max_nocheck
(magma_trans_t transA, magma_trans_t transB,
Expand All @@ -282,7 +282,7 @@ namespace strumpack {
(transA, transB, m, n, k, alpha, dA_array, ldda,
dB_array, lddb, beta, dC_array, lddc, batchCount,
max_m, max_n, max_k, queue);
gpu_peek_at_last_error();
get_last_error();
}
inline void gemm_vbatched_max_nocheck
(magma_trans_t transA, magma_trans_t transB,
Expand All @@ -304,7 +304,7 @@ namespace strumpack {
(magmaFloatComplex**)dB_array, lddb, beta_,
(magmaFloatComplex**)dC_array, lddc, batchCount,
max_m, max_n, max_k, queue);
gpu_peek_at_last_error();
get_last_error();
}
inline void gemm_vbatched_max_nocheck
(magma_trans_t transA, magma_trans_t transB,
Expand All @@ -326,7 +326,7 @@ namespace strumpack {
(magmaDoubleComplex**)dB_array, lddb, beta_,
(magmaDoubleComplex**)dC_array, lddc, batchCount,
max_m, max_n, max_k, queue);
gpu_peek_at_last_error();
get_last_error();
}

inline void gemv_vbatched_max_nocheck
Expand All @@ -342,7 +342,7 @@ namespace strumpack {
const_cast<float**>(dA_array), ldda,
const_cast<float**>(dB_array), lddb, beta,
dC_array, lddc, batchCount, max_m, max_n, queue);
gpu_peek_at_last_error();
get_last_error();
}
inline void gemv_vbatched_max_nocheck
(magma_trans_t trans, magma_int_t *m, magma_int_t *n, double alpha,
Expand All @@ -357,7 +357,7 @@ namespace strumpack {
const_cast<double**>(dA_array), ldda,
const_cast<double**>(dB_array), lddb, beta,
dC_array, lddc, batchCount, max_m, max_n, queue);
gpu_peek_at_last_error();
get_last_error();
}
inline void gemv_vbatched_max_nocheck
(magma_trans_t trans, magma_int_t *m, magma_int_t *n, std::complex<float> alpha,
Expand All @@ -376,7 +376,7 @@ namespace strumpack {
(magmaFloatComplex**)(const_cast<std::complex<float>**>(dB_array)), lddb, beta_,
(magmaFloatComplex**)dC_array, lddc, batchCount,
max_m, max_n, queue);
gpu_peek_at_last_error();
get_last_error();
}
inline void gemv_vbatched_max_nocheck
(magma_trans_t trans, magma_int_t *m, magma_int_t *n,
Expand All @@ -396,7 +396,7 @@ namespace strumpack {
(magmaDoubleComplex**)(const_cast<std::complex<double>**>(dB_array)), lddb, beta_,
(magmaDoubleComplex**)dC_array, lddc, batchCount,
max_m, max_n, queue);
gpu_peek_at_last_error();
get_last_error();
}

} // end namespace magma
Expand Down

0 comments on commit df2da3f

Please sign in to comment.