
Commit

bugfixes
maximilianbehr committed Feb 18, 2024
1 parent 41224b7 commit 933eae0
Showing 3 changed files with 18 additions and 21 deletions.
4 changes: 2 additions & 2 deletions README.md
@@ -4,15 +4,15 @@
![GitHub Release](https://img.shields.io/github/v/release/maximilianbehr/cuexpm?display_name=release&style=flat)
![GitHub Downloads (all assets, all releases)](https://img.shields.io/github/downloads/maximilianbehr/cuexpm/total)

-**Version:** 1.0.0
+**Version:** 1.0.1

**Copyright:** Maximilian Behr

**License:** The software is licensed under MIT. See [`LICENSE`](LICENSE) for details.

`cuexpm` is a `CUDA` library for the numerical approximation of the matrix exponential $e^A$.

-`cuexpm` support single and double precision as well as real and complex matrices.
+`cuexpm` supports single and double precision as well as real and complex matrices.


| Functions | Data |
31 changes: 14 additions & 17 deletions cuexpm.cu
@@ -108,7 +108,6 @@ struct cuexpm_traits<double> {
static constexpr double one = 1.;
static constexpr double mone = -1.;
static constexpr double zero = 0.;
-static constexpr double ftwo = 2.;

/*-----------------------------------------------------------------------------
* Pade coefficients
@@ -157,7 +156,6 @@ struct cuexpm_traits<float> {
static constexpr float one = 1.f;
static constexpr float mone = -1.f;
static constexpr float zero = 0.f;
-static constexpr float ftwo = 2.f;

/*-----------------------------------------------------------------------------
* Pade coefficients
@@ -206,7 +204,6 @@ struct cuexpm_traits<cuDoubleComplex> {
static constexpr cuDoubleComplex one = {1., 0.};
static constexpr cuDoubleComplex mone = {-1., 0.};
static constexpr cuDoubleComplex zero = {0., 0.};
-static constexpr double ftwo = 2.;

/*-----------------------------------------------------------------------------
* Pade coefficients
@@ -255,7 +252,6 @@ struct cuexpm_traits<cuComplex> {
static constexpr cuComplex one = {1.f, 0.f};
static constexpr cuComplex mone = {-1.f, 0.f};
static constexpr cuComplex zero = {0.f, 0.f};
-static constexpr float ftwo = 2.f;

/*-----------------------------------------------------------------------------
* Pade coefficients
@@ -462,8 +458,10 @@ static int cuexpm(const T *d_A, const int n, void *d_buffer, void *h_buffer, T *
/*-----------------------------------------------------------------------------
* kernel launch parameters
*-----------------------------------------------------------------------------*/
-const size_t threadsPerBlock = 256;
-const size_t blocksPerGrid = (n + threadsPerBlock - 1) / threadsPerBlock;
+const size_t threadsPerBlock = 256; // addDiag
+const size_t blocksPerGrid = (n + threadsPerBlock - 1) / threadsPerBlock; // addDiag
+dim3 grid((n + 15) / 16, (n + 15) / 16); // setDiag
+dim3 block(16, 16); // setDiag
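This hunk is the core of the fix: before this commit, the `setDiag` launches in the `m == 3/5/7/9` branches reused the 1D `addDiag` configuration of `n` threads, while the `m == 13` branch already used a 2D grid. Below is a minimal sketch of why the two shapes differ, assuming (from the kernel names and the 2D launch) that `addDiag` updates only the `n` diagonal entries, whereas `setDiag` overwrites the whole column-major `n`-by-`n` matrix with a scaled identity. The bodies are illustrative, not the library's actual kernels.

```cuda
#include <cuda_runtime.h>

// Illustrative sketch: adds alpha to the n diagonal entries,
// so a 1D launch with at least n threads suffices.
__global__ void addDiag_sketch(double *A, int n, double alpha) {
    int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i < n) A[(size_t)i * n + i] += alpha;  // column-major A(i,i)
}

// Illustrative sketch: overwrites every entry (alpha on the diagonal,
// zero elsewhere), so it must cover all n*n entries with a 2D launch;
// the 1D configuration above cannot reach them all.
__global__ void setDiag_sketch(double *A, int n, double alpha) {
    int i = blockIdx.x * blockDim.x + threadIdx.x;  // row
    int j = blockIdx.y * blockDim.y + threadIdx.y;  // column
    if (i < n && j < n) A[(size_t)j * n + i] = (i == j) ? alpha : 0.0;
}
```

With `block = dim3(16, 16)`, the ceiling division `(n + 15) / 16` in each grid dimension guarantees at least one thread per matrix entry for any `n`.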

/*-----------------------------------------------------------------------------
* compute the scaling parameter and Pade approximant degree
@@ -523,49 +521,49 @@
CHECK_CUDA(cudaStreamCreate(&streamV));
if (m == 3) {
// U = U + c(3)*T2 + c(1)*I
-setDiag<<<blocksPerGrid, threadsPerBlock, 0, streamU>>>(U, n, cuexpm_traits<T>::Pade3[1]);
+setDiag<<<grid, block, 0, streamU>>>(U, n, cuexpm_traits<T>::Pade3[1]);
CHECK_CUDA(cudaPeekAtLastError());
CHECK_CUBLAS(cublasSetStream(cublasH, streamU));
CHECK_CUBLAS(cuexpm_traits<T>::cublasXgeam(cublasH, CUBLAS_OP_N, CUBLAS_OP_N, n, n, &cuexpm_traits<T>::one, U, n, &cuexpm_traits<T>::Pade3[3], T2, n, U, n));

// V = V + c(2)*T2 + c(0)*I
-setDiag<<<blocksPerGrid, threadsPerBlock, 0, streamV>>>(V, n, cuexpm_traits<T>::Pade3[0]);
+setDiag<<<grid, block, 0, streamV>>>(V, n, cuexpm_traits<T>::Pade3[0]);
CHECK_CUDA(cudaPeekAtLastError());
CHECK_CUBLAS(cublasSetStream(cublasH, streamV));
CHECK_CUBLAS(cuexpm_traits<T>::cublasXgeam(cublasH, CUBLAS_OP_N, CUBLAS_OP_N, n, n, &cuexpm_traits<T>::one, V, n, &cuexpm_traits<T>::Pade3[2], T2, n, V, n));
} else if (m == 5) {
// U = U + c(5)*T4 + c(3)*T2 + c(1)*I
-setDiag<<<blocksPerGrid, threadsPerBlock, 0, streamU>>>(U, n, cuexpm_traits<T>::Pade5[1]);
+setDiag<<<grid, block, 0, streamU>>>(U, n, cuexpm_traits<T>::Pade5[1]);
CHECK_CUDA(cudaPeekAtLastError());
CHECK_CUBLAS(cublasSetStream(cublasH, streamU));
CHECK_CUBLAS(cuexpm_traits<T>::cublasXgeam(cublasH, CUBLAS_OP_N, CUBLAS_OP_N, n, n, &cuexpm_traits<T>::one, U, n, &cuexpm_traits<T>::Pade5[3], T2, n, U, n));
CHECK_CUBLAS(cuexpm_traits<T>::cublasXgeam(cublasH, CUBLAS_OP_N, CUBLAS_OP_N, n, n, &cuexpm_traits<T>::one, U, n, &cuexpm_traits<T>::Pade5[5], T4, n, U, n));

// V = V + c(4)*T4 + c(2)*T2 + c(0)*I
-setDiag<<<blocksPerGrid, threadsPerBlock, 0, streamV>>>(V, n, cuexpm_traits<T>::Pade5[0]);
+setDiag<<<grid, block, 0, streamV>>>(V, n, cuexpm_traits<T>::Pade5[0]);
CHECK_CUDA(cudaPeekAtLastError());
CHECK_CUBLAS(cublasSetStream(cublasH, streamV));
CHECK_CUBLAS(cuexpm_traits<T>::cublasXgeam(cublasH, CUBLAS_OP_N, CUBLAS_OP_N, n, n, &cuexpm_traits<T>::one, V, n, &cuexpm_traits<T>::Pade5[2], T2, n, V, n));
CHECK_CUBLAS(cuexpm_traits<T>::cublasXgeam(cublasH, CUBLAS_OP_N, CUBLAS_OP_N, n, n, &cuexpm_traits<T>::one, V, n, &cuexpm_traits<T>::Pade5[4], T4, n, V, n));
} else if (m == 7) {
// U = U + c(7)*T6 + c(5)*T4 + c(3)*T2 + c(1)*I
-setDiag<<<blocksPerGrid, threadsPerBlock, 0, streamU>>>(U, n, cuexpm_traits<T>::Pade7[1]);
+setDiag<<<grid, block, 0, streamU>>>(U, n, cuexpm_traits<T>::Pade7[1]);
CHECK_CUDA(cudaPeekAtLastError());
CHECK_CUBLAS(cublasSetStream(cublasH, streamU));
CHECK_CUBLAS(cuexpm_traits<T>::cublasXgeam(cublasH, CUBLAS_OP_N, CUBLAS_OP_N, n, n, &cuexpm_traits<T>::one, U, n, &cuexpm_traits<T>::Pade7[3], T2, n, U, n));
CHECK_CUBLAS(cuexpm_traits<T>::cublasXgeam(cublasH, CUBLAS_OP_N, CUBLAS_OP_N, n, n, &cuexpm_traits<T>::one, U, n, &cuexpm_traits<T>::Pade7[5], T4, n, U, n));
CHECK_CUBLAS(cuexpm_traits<T>::cublasXgeam(cublasH, CUBLAS_OP_N, CUBLAS_OP_N, n, n, &cuexpm_traits<T>::one, U, n, &cuexpm_traits<T>::Pade7[7], T6, n, U, n));

// V = V + c(6)*T6 + c(4)*T4 + c(2)*T2 + c(0)*I
-setDiag<<<blocksPerGrid, threadsPerBlock, 0, streamV>>>(V, n, cuexpm_traits<T>::Pade7[0]);
+setDiag<<<grid, block, 0, streamV>>>(V, n, cuexpm_traits<T>::Pade7[0]);
CHECK_CUDA(cudaPeekAtLastError());
CHECK_CUBLAS(cublasSetStream(cublasH, streamV));
CHECK_CUBLAS(cuexpm_traits<T>::cublasXgeam(cublasH, CUBLAS_OP_N, CUBLAS_OP_N, n, n, &cuexpm_traits<T>::one, V, n, &cuexpm_traits<T>::Pade7[2], T2, n, V, n));
CHECK_CUBLAS(cuexpm_traits<T>::cublasXgeam(cublasH, CUBLAS_OP_N, CUBLAS_OP_N, n, n, &cuexpm_traits<T>::one, V, n, &cuexpm_traits<T>::Pade7[4], T4, n, V, n));
CHECK_CUBLAS(cuexpm_traits<T>::cublasXgeam(cublasH, CUBLAS_OP_N, CUBLAS_OP_N, n, n, &cuexpm_traits<T>::one, V, n, &cuexpm_traits<T>::Pade7[6], T6, n, V, n));
} else if (m == 9) {
// U = U + c(9)*T8 + c(7)*T6 + c(5)*T4 + c(3)*T2 + c(1)*I
-setDiag<<<blocksPerGrid, threadsPerBlock, 0, streamU>>>(U, n, cuexpm_traits<T>::Pade9[1]);
+setDiag<<<grid, block, 0, streamU>>>(U, n, cuexpm_traits<T>::Pade9[1]);
CHECK_CUDA(cudaPeekAtLastError());
CHECK_CUBLAS(cublasSetStream(cublasH, streamU));
CHECK_CUBLAS(cuexpm_traits<T>::cublasXgeam(cublasH, CUBLAS_OP_N, CUBLAS_OP_N, n, n, &cuexpm_traits<T>::one, U, n, &cuexpm_traits<T>::Pade9[3], T2, n, U, n));
@@ -574,17 +572,16 @@ static int cuexpm(const T *d_A, const int n, void *d_buffer, void *h_buffer, T *
CHECK_CUBLAS(cuexpm_traits<T>::cublasXgeam(cublasH, CUBLAS_OP_N, CUBLAS_OP_N, n, n, &cuexpm_traits<T>::one, U, n, &cuexpm_traits<T>::Pade9[9], T8, n, U, n));

// V = V + c(6)*T6 + c(4)*T4 + c(2)*T2 + c(0)*I
-setDiag<<<blocksPerGrid, threadsPerBlock, 0, streamV>>>(V, n, cuexpm_traits<T>::Pade9[0]);
+setDiag<<<grid, block, 0, streamV>>>(V, n, cuexpm_traits<T>::Pade9[0]);
CHECK_CUDA(cudaPeekAtLastError());
CHECK_CUBLAS(cublasSetStream(cublasH, streamV));
CHECK_CUBLAS(cuexpm_traits<T>::cublasXgeam(cublasH, CUBLAS_OP_N, CUBLAS_OP_N, n, n, &cuexpm_traits<T>::one, V, n, &cuexpm_traits<T>::Pade9[2], T2, n, V, n));
CHECK_CUBLAS(cuexpm_traits<T>::cublasXgeam(cublasH, CUBLAS_OP_N, CUBLAS_OP_N, n, n, &cuexpm_traits<T>::one, V, n, &cuexpm_traits<T>::Pade9[4], T4, n, V, n));
CHECK_CUBLAS(cuexpm_traits<T>::cublasXgeam(cublasH, CUBLAS_OP_N, CUBLAS_OP_N, n, n, &cuexpm_traits<T>::one, V, n, &cuexpm_traits<T>::Pade9[6], T6, n, V, n));
CHECK_CUBLAS(cuexpm_traits<T>::cublasXgeam(cublasH, CUBLAS_OP_N, CUBLAS_OP_N, n, n, &cuexpm_traits<T>::one, V, n, &cuexpm_traits<T>::Pade9[8], T8, n, V, n));
} else if (m == 13) {
-dim3 grid((n + 15) / 16, (n + 15) / 16);
// U = T6*(c(13)*T6 + c(11)*T4 + c(9)*T2) + c(7)*T6 + c(5)*T4 + c(3)*T2 + c(1)*I;
-setDiag<<<grid, dim3(16, 16), 0, streamU>>>(U, n, cuexpm_traits<T>::Pade13[1]);
+setDiag<<<grid, block, 0, streamU>>>(U, n, cuexpm_traits<T>::Pade13[1]);
CHECK_CUDA(cudaPeekAtLastError());
CHECK_CUBLAS(cublasSetStream(cublasH, streamU));
CHECK_CUBLAS(cuexpm_traits<T>::cublasXgeam(cublasH, CUBLAS_OP_N, CUBLAS_OP_N, n, n, &cuexpm_traits<T>::one, U, n, &cuexpm_traits<T>::Pade13[3], T2, n, U, n));
@@ -595,7 +592,7 @@ static int cuexpm(const T *d_A, const int n, void *d_buffer, void *h_buffer, T *
CHECK_CUBLAS(cuexpm_traits<T>::cublasXgemm(cublasH, CUBLAS_OP_N, CUBLAS_OP_N, n, n, n, &cuexpm_traits<T>::one, T6, n, T8, n, &cuexpm_traits<T>::one, U, n));

// V = T6*(c(12)*T6 + c(10)*T4 + c(8)*T2) + c(6)*T6 + c(4)*T4 + c(2)*T2 + c(0)*I;
-setDiag<<<grid, dim3(16, 16), 0, streamV>>>(V, n, cuexpm_traits<T>::Pade13[0]);
+setDiag<<<grid, block, 0, streamV>>>(V, n, cuexpm_traits<T>::Pade13[0]);
CHECK_CUDA(cudaPeekAtLastError());
CHECK_CUBLAS(cublasSetStream(cublasH, streamV));
CHECK_CUBLAS(cuexpm_traits<T>::cublasXgeam(cublasH, CUBLAS_OP_N, CUBLAS_OP_N, n, n, &cuexpm_traits<T>::one, V, n, &cuexpm_traits<T>::Pade13[2], T2, n, V, n));
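For orientation while reading the `U`/`V` accumulation above: the branches on `m` implement the standard scaling-and-squaring evaluation of the diagonal Padé approximant of the matrix exponential (Higham's `expm` scheme). `V` collects the even-power terms, while the code's `U` holds the odd-coefficient polynomial in `T2 = A^2`, presumably multiplied by `A` elsewhere in the routine (not shown in this diff). A sketch of the underlying formulas, stated from the standard method rather than read out of this file:

```latex
% Degree-m diagonal Padé approximant of e^A with coefficients c_k,
% split into an odd part U and an even part V:
U = A \left( c_1 I + c_3 A^2 + c_5 A^4 + \cdots \right), \qquad
V = c_0 I + c_2 A^2 + c_4 A^4 + \cdots
% Numerator and denominator share these two pieces:
p_m(A) = U + V, \qquad p_m(-A) = -U + V,
% so one linear solve yields the approximant
r_m(A) = (-U + V)^{-1}\,(U + V) \approx e^{A},
% and with scaling parameter s:
e^A \approx \bigl( r_m(2^{-s} A) \bigr)^{2^s}.
```

The `Pade3` through `Pade13` arrays referenced above supply the coefficients `c_k` for each admissible degree `m`.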
4 changes: 2 additions & 2 deletions example_cuexpms.cu
@@ -39,7 +39,7 @@ int main(void) {
void *d_buffer = NULL; // memory buffer on the device
void *h_buffer = NULL; // memory buffer on the host

-/*------------------------ -----------------------------------------------------
+/*-----------------------------------------------------------------------------
* allocate A and expmA on the host
*-----------------------------------------------------------------------------*/
cudaMallocHost((void **)&A, sizeof(*A) * n * n);
@@ -121,6 +121,6 @@ int main(void) {
cudaFree(d_A);
cudaFree(d_expmA);
cudaFree(d_buffer);
-cudaFree(h_buffer);
+cudaFreeHost(h_buffer);
return 0;
}
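The last hunk fixes an allocation/deallocation mismatch: `cudaFree` releases memory obtained from `cudaMalloc`, while page-locked host memory from `cudaMallocHost` (used for `A` and `expmA` above, and presumably for `h_buffer` as well) must be released with `cudaFreeHost`. A minimal self-contained sketch of the correct pairing:

```cuda
#include <cuda_runtime.h>

int main(void) {
    void *h_buffer = NULL;  // pinned (page-locked) host memory
    void *d_buffer = NULL;  // device memory

    cudaMallocHost(&h_buffer, 1024);  // host side: page-locked allocation
    cudaMalloc(&d_buffer, 1024);      // device side

    /* ... use the buffers ... */

    cudaFreeHost(h_buffer);  // pairs with cudaMallocHost (the bug used cudaFree here)
    cudaFree(d_buffer);      // pairs with cudaMalloc
    return 0;
}
```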
