diff --git a/config/zen/bli_cntx_init_zen.c b/config/zen/bli_cntx_init_zen.c index ed7287cee..f3ba741be 100644 --- a/config/zen/bli_cntx_init_zen.c +++ b/config/zen/bli_cntx_init_zen.c @@ -86,20 +86,20 @@ void bli_cntx_init_zen( cntx_t* cntx ) ); #endif - // Update the context with optimized level-1f kernels. - bli_cntx_set_l1f_kers - ( - 4, - - // axpyf - BLIS_AXPYF_KER, BLIS_FLOAT, bli_saxpyf_zen_int_8, - BLIS_AXPYF_KER, BLIS_DOUBLE, bli_daxpyf_zen_int_8, - - // dotxf - BLIS_DOTXF_KER, BLIS_FLOAT, bli_sdotxf_zen_int_8, - BLIS_DOTXF_KER, BLIS_DOUBLE, bli_ddotxf_zen_int_8, - cntx - ); + // Update the context with optimized level-1f kernels. + bli_cntx_set_l1f_kers + ( + 6, + // axpyf + BLIS_AXPYF_KER, BLIS_FLOAT, bli_saxpyf_zen_int_8, + BLIS_AXPYF_KER, BLIS_DOUBLE, bli_daxpyf_zen_int_8, + BLIS_AXPYF_KER, BLIS_SCOMPLEX, bli_caxpyf_zen_int_5, + BLIS_AXPYF_KER, BLIS_DCOMPLEX, bli_zaxpyf_zen_int_5, + // dotxf + BLIS_DOTXF_KER, BLIS_FLOAT, bli_sdotxf_zen_int_8, + BLIS_DOTXF_KER, BLIS_DOUBLE, bli_ddotxf_zen_int_8, + cntx + ); // Update the context with optimized level-1v kernels. bli_cntx_set_l1v_kers diff --git a/config/zen2/bli_cntx_init_zen2.c b/config/zen2/bli_cntx_init_zen2.c index 0964ce463..84528ec5d 100644 --- a/config/zen2/bli_cntx_init_zen2.c +++ b/config/zen2/bli_cntx_init_zen2.c @@ -84,20 +84,20 @@ void bli_cntx_init_zen2( cntx_t* cntx ) ); #endif - // Update the context with optimized level-1f kernels. - bli_cntx_set_l1f_kers - ( - 4, - - // axpyf - BLIS_AXPYF_KER, BLIS_FLOAT, bli_saxpyf_zen_int_5, - BLIS_AXPYF_KER, BLIS_DOUBLE, bli_daxpyf_zen_int_5, - - // dotxf - BLIS_DOTXF_KER, BLIS_FLOAT, bli_sdotxf_zen_int_8, - BLIS_DOTXF_KER, BLIS_DOUBLE, bli_ddotxf_zen_int_8, - cntx - ); + // Update the context with optimized level-1f kernels. + bli_cntx_set_l1f_kers + ( + 6, + // axpyf + BLIS_AXPYF_KER, BLIS_FLOAT, bli_saxpyf_zen_int_5, + BLIS_AXPYF_KER, BLIS_DOUBLE, bli_daxpyf_zen_int_5, + BLIS_AXPYF_KER, BLIS_SCOMPLEX, bli_caxpyf_zen_int_5, + BLIS_AXPYF_KER, BLIS_DCOMPLEX, bli_zaxpyf_zen_int_5, + // dotxf + BLIS_DOTXF_KER, BLIS_FLOAT, bli_sdotxf_zen_int_8, + BLIS_DOTXF_KER, BLIS_DOUBLE, bli_ddotxf_zen_int_8, + cntx + ); // Update the context with optimized level-1v kernels. bli_cntx_set_l1v_kers diff --git a/config/zen3/bli_cntx_init_zen3.c b/config/zen3/bli_cntx_init_zen3.c index b5bbb05ed..65a4f008d 100644 --- a/config/zen3/bli_cntx_init_zen3.c +++ b/config/zen3/bli_cntx_init_zen3.c @@ -4,6 +4,7 @@ An object-based framework for developing high-performance BLAS-like libraries. + Copyright (C) 2014, The University of Texas at Austin Copyright (C) 2020, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without @@ -36,6 +37,7 @@ void bli_cntx_init_zen3( cntx_t* cntx ) { + blksz_t blkszs[ BLIS_NUM_BLKSZS ]; blksz_t thresh[ BLIS_NUM_THRESH ]; @@ -92,19 +94,21 @@ void bli_cntx_init_zen3( cntx_t* cntx ) ); #endif - // Update the context with optimized level-1f kernels. - bli_cntx_set_l1f_kers - ( - 4, - // axpyf - BLIS_AXPYF_KER, BLIS_FLOAT, bli_saxpyf_zen_int_5, - BLIS_AXPYF_KER, BLIS_DOUBLE, bli_daxpyf_zen_int_5, - // dotxf - BLIS_DOTXF_KER, BLIS_FLOAT, bli_sdotxf_zen_int_8, - BLIS_DOTXF_KER, BLIS_DOUBLE, bli_ddotxf_zen_int_8, - cntx - ); - + // Update the context with optimized level-1f kernels. 
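+	// Note: the kernel count below is raised from 4 to 6 so that the new
+	// scomplex/dcomplex fused axpyf kernels are registered alongside the
+	// existing real-domain kernels.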
+ bli_cntx_set_l1f_kers + ( + 6, + // axpyf + BLIS_AXPYF_KER, BLIS_FLOAT, bli_saxpyf_zen_int_5, + BLIS_AXPYF_KER, BLIS_DOUBLE, bli_daxpyf_zen_int_5, + BLIS_AXPYF_KER, BLIS_SCOMPLEX, bli_caxpyf_zen_int_5, + BLIS_AXPYF_KER, BLIS_DCOMPLEX, bli_zaxpyf_zen_int_5, + // dotxf + BLIS_DOTXF_KER, BLIS_FLOAT, bli_sdotxf_zen_int_8, + BLIS_DOTXF_KER, BLIS_DOUBLE, bli_ddotxf_zen_int_8, + cntx + ); + // Update the context with optimized level-1v kernels. bli_cntx_set_l1v_kers ( @@ -295,4 +299,3 @@ void bli_cntx_init_zen3( cntx_t* cntx ) cntx ); } - diff --git a/frame/2/gemv/bli_gemv_unf_var2.c b/frame/2/gemv/bli_gemv_unf_var2.c index fe7702e4c..ac20c33b0 100644 --- a/frame/2/gemv/bli_gemv_unf_var2.c +++ b/frame/2/gemv/bli_gemv_unf_var2.c @@ -51,81 +51,371 @@ void PASTEMAC(ch,varname) \ cntx_t* cntx \ ) \ { \ - const num_t dt = PASTEMAC(ch,type); \ \ - ctype* zero = PASTEMAC(ch,0); \ - ctype* A1; \ - ctype* x1; \ - ctype* y1; \ - dim_t i; \ - dim_t b_fuse, f; \ - dim_t n_elem, n_iter; \ - inc_t rs_at, cs_at; \ - conj_t conja; \ + const num_t dt = PASTEMAC(ch,type); \ \ - bli_set_dims_incs_with_trans( transa, \ - m, n, rs_a, cs_a, \ - &n_elem, &n_iter, &rs_at, &cs_at ); \ + ctype* zero = PASTEMAC(ch,0); \ + ctype* A1; \ + ctype* x1; \ + ctype* y1; \ + dim_t i; \ + dim_t b_fuse, f; \ + dim_t n_elem, n_iter; \ + inc_t rs_at, cs_at; \ + conj_t conja; \ \ - conja = bli_extract_conj( transa ); \ + bli_set_dims_incs_with_trans( transa, \ + m, n, rs_a, cs_a, \ + &n_elem, &n_iter, &rs_at, &cs_at ); \ \ - /* If beta is zero, use setv. Otherwise, scale by beta. */ \ - if ( PASTEMAC(ch,eq0)( *beta ) ) \ - { \ - /* y = 0; */ \ - PASTEMAC2(ch,setv,BLIS_TAPI_EX_SUF) \ - ( \ - BLIS_NO_CONJUGATE, \ - n_elem, \ - zero, \ - y, incy, \ - cntx, \ - NULL \ - ); \ - } \ - else \ - { \ - /* y = beta * y; */ \ - PASTEMAC2(ch,scalv,BLIS_TAPI_EX_SUF) \ - ( \ - BLIS_NO_CONJUGATE, \ - n_elem, \ - beta, \ - y, incy, \ - cntx, \ - NULL \ - ); \ - } \ + conja = bli_extract_conj( transa ); \ \ - PASTECH(ch,axpyf_ker_ft) kfp_af; \ + /* If beta is zero, use setv. Otherwise, scale by beta. */ \ + if ( PASTEMAC(ch,eq0)( *beta ) ) \ + { \ + /* y = 0; */ \ + PASTEMAC2(ch,setv,BLIS_TAPI_EX_SUF) \ + ( \ + BLIS_NO_CONJUGATE, \ + n_elem, \ + zero, \ + y, incy, \ + cntx, \ + NULL \ + ); \ + } \ + else \ + { \ + /* y = beta * y; */ \ + PASTEMAC2(ch,scalv,BLIS_TAPI_EX_SUF) \ + ( \ + BLIS_NO_CONJUGATE, \ + n_elem, \ + beta, \ + y, incy, \ + cntx, \ + NULL \ + ); \ + } \ \ - /* Query the context for the kernel function pointer and fusing factor. */ \ - kfp_af = bli_cntx_get_l1f_ker_dt( dt, BLIS_AXPYF_KER, cntx ); \ - b_fuse = bli_cntx_get_blksz_def_dt( dt, BLIS_AF, cntx ); \ + PASTECH(ch,axpyf_ker_ft) kfp_af; \ \ - for ( i = 0; i < n_iter; i += f ) \ - { \ - f = bli_determine_blocksize_dim_f( i, n_iter, b_fuse ); \ + /* Query the context for the kernel function pointer and fusing factor. 
*/ \ + kfp_af = bli_cntx_get_l1f_ker_dt( dt, BLIS_AXPYF_KER, cntx ); \ + b_fuse = bli_cntx_get_blksz_def_dt( dt, BLIS_AF, cntx ); \ \ - A1 = a + (0 )*rs_at + (i )*cs_at; \ - x1 = x + (i )*incx; \ - y1 = y + (0 )*incy; \ + for ( i = 0; i < n_iter; i += f ) \ + { \ + f = bli_determine_blocksize_dim_f( i, n_iter, b_fuse ); \ \ - /* y = y + alpha * A1 * x1; */ \ - kfp_af \ - ( \ - conja, \ - conjx, \ - n_elem, \ - f, \ - alpha, \ - A1, rs_at, cs_at, \ - x1, incx, \ - y1, incy, \ - cntx \ - ); \ - } \ + A1 = a + (0 )*rs_at + (i )*cs_at; \ + x1 = x + (i )*incx; \ + y1 = y + (0 )*incy; \ +\ + /* y = y + alpha * A1 * x1; */ \ + kfp_af \ + ( \ + conja, \ + conjx, \ + n_elem, \ + f, \ + alpha, \ + A1, rs_at, cs_at, \ + x1, incx, \ + y1, incy, \ + cntx \ + ); \ + } \ } -INSERT_GENTFUNC_BASIC0( gemv_unf_var2 ) +#ifdef BLIS_CONFIG_EPYC + +void bli_dgemv_unf_var2 + ( + trans_t transa, + conj_t conjx, + dim_t m, + dim_t n, + double* alpha, + double* a, inc_t rs_a, inc_t cs_a, + double* x, inc_t incx, + double* beta, + double* y, inc_t incy, + cntx_t* cntx + ) +{ + + double* A1; + double* x1; + double* y1; + dim_t i; + dim_t b_fuse, f; + dim_t n_elem, n_iter; + inc_t rs_at, cs_at; + conj_t conja; + + bli_set_dims_incs_with_trans( transa, + m, n, rs_a, cs_a, + &n_elem, &n_iter, &rs_at, &cs_at ); + + conja = bli_extract_conj( transa ); + + /* If beta is zero, use setv. Otherwise, scale by beta. */ + /* y = beta * y; */ + /* beta=0 case is hadled by scalv internally */ + + bli_dscalv_zen_int10 + ( + BLIS_NO_CONJUGATE, + n_elem, + beta, + y, incy, + NULL + ); + + /* Query the context for the kernel function pointer and fusing factor. */ + b_fuse = 5; + + for ( i = 0; i < n_iter; i += f ) + { + f = bli_determine_blocksize_dim_f( i, n_iter, b_fuse ); + + A1 = a + (0 )*rs_at + (i )*cs_at; + x1 = x + (i )*incx; + y1 = y + (0 )*incy; + + /* y = y + alpha * A1 * x1; */ + bli_daxpyf_zen_int_5 + ( + conja, + conjx, + n_elem, + f, + alpha, + A1, rs_at, cs_at, + x1, incx, + y1, incy, + NULL + ); + } +} + +void bli_sgemv_unf_var2 + ( + trans_t transa, + conj_t conjx, + dim_t m, + dim_t n, + float* alpha, + float* a, inc_t rs_a, inc_t cs_a, + float* x, inc_t incx, + float* beta, + float* y, inc_t incy, + cntx_t* cntx + ) +{ + + float* A1; + float* x1; + float* y1; + dim_t i; + dim_t b_fuse, f; + dim_t n_elem, n_iter; + inc_t rs_at, cs_at; + conj_t conja; + + bli_set_dims_incs_with_trans( transa, + m, n, rs_a, cs_a, + &n_elem, &n_iter, &rs_at, &cs_at ); + + conja = bli_extract_conj( transa ); + + /* If beta is zero, use setv. Otherwise, scale by beta. */ + /* y = beta * y; */ + /* beta=0 case is hadled by scalv internally */ + + bli_sscalv_zen_int10 + ( + BLIS_NO_CONJUGATE, + n_elem, + beta, + y, incy, + NULL + ); + + /* Query the context for the kernel function pointer and fusing factor. 
*/ + b_fuse = 5; + + for ( i = 0; i < n_iter; i += f ) + { + f = bli_determine_blocksize_dim_f( i, n_iter, b_fuse ); + + A1 = a + (0 )*rs_at + (i )*cs_at; + x1 = x + (i )*incx; + y1 = y + (0 )*incy; + + /* y = y + alpha * A1 * x1; */ + bli_saxpyf_zen_int_5 + ( + conja, + conjx, + n_elem, + f, + alpha, + A1, rs_at, cs_at, + x1, incx, + y1, incy, + NULL + ); + } +} + +void bli_zgemv_unf_var2 + ( + trans_t transa, + conj_t conjx, + dim_t m, + dim_t n, + dcomplex* alpha, + dcomplex* a, inc_t rs_a, inc_t cs_a, + dcomplex* x, inc_t incx, + dcomplex* beta, + dcomplex* y, inc_t incy, + cntx_t* cntx + ) +{ + + dcomplex* A1; + dcomplex* x1; + dcomplex* y1; + dim_t i; + dim_t b_fuse, f; + dim_t n_elem, n_iter; + inc_t rs_at, cs_at; + conj_t conja; + + bli_set_dims_incs_with_trans( transa, + m, n, rs_a, cs_a, + &n_elem, &n_iter, &rs_at, &cs_at ); + + conja = bli_extract_conj( transa ); + + /* If beta is zero, use setv. Otherwise, scale by beta. */ + /* y = beta * y; */ + /* beta=0 case is hadled by scalv internally */ + + bli_zscalv_ex + ( + BLIS_NO_CONJUGATE, + n_elem, + beta, + y, incy, + cntx, + NULL + ); + + if( bli_zeq0( *alpha ) ) + { + return; + } + + /* fusing factor */ + b_fuse = 4; + + for ( i = 0; i < n_iter; i += f ) + { + f = bli_determine_blocksize_dim_f( i, n_iter, b_fuse ); + A1 = a + (0 )*rs_at + (i )*cs_at; + x1 = x + (i )*incx; + y1 = y + (0 )*incy; + + /* y = y + alpha * A1 * x1; */ + bli_zaxpyf_zen_int_4 + ( + conja, + conjx, + n_elem, + f, + alpha, + A1, rs_at, cs_at, + x1, incx, + y1, incy, + NULL + ); + } +} + +void bli_cgemv_unf_var2 + ( + trans_t transa, + conj_t conjx, + dim_t m, + dim_t n, + scomplex* alpha, + scomplex* a, inc_t rs_a, inc_t cs_a, + scomplex* x, inc_t incx, + scomplex* beta, + scomplex* y, inc_t incy, + cntx_t* cntx + ) +{ + + scomplex* A1; + scomplex* x1; + scomplex* y1; + dim_t i; + dim_t b_fuse, f; + dim_t n_elem, n_iter; + inc_t rs_at, cs_at; + conj_t conja; + + bli_set_dims_incs_with_trans( transa, + m, n, rs_a, cs_a, + &n_elem, &n_iter, &rs_at, &cs_at ); + + conja = bli_extract_conj( transa ); + + /* If beta is zero, use setv. Otherwise, scale by beta. */ + /* y = beta * y; */ + /* beta=0 case is hadled by scalv internally */ + bli_cscalv_ex + ( + BLIS_NO_CONJUGATE, + n_elem, + beta, + y, incy, + cntx, + NULL + ); + + /* fusing factor. */ + b_fuse = 5; + + for ( i = 0; i < n_iter; i += f ) + { + f = bli_determine_blocksize_dim_f( i, n_iter, b_fuse ); + A1 = a + (0 )*rs_at + (i )*cs_at; + x1 = x + (i )*incx; + y1 = y + (0 )*incy; + + /* y = y + alpha * A1 * x1; */ + bli_caxpyf_zen_int_5 + ( + conja, + conjx, + n_elem, + f, + alpha, + A1, rs_at, cs_at, + x1, incx, + y1, incy, + NULL + ); + } +} + + +#else +INSERT_GENTFUNC_BASIC0( gemv_unf_var2 ) +#endif diff --git a/frame/compat/bla_gemv.c b/frame/compat/bla_gemv.c index 85c65dde4..63502c0e9 100644 --- a/frame/compat/bla_gemv.c +++ b/frame/compat/bla_gemv.c @@ -137,6 +137,653 @@ void PASTEF77(ch,blasname) \ } #ifdef BLIS_ENABLE_BLAS +#ifdef BLIS_CONFIG_EPYC +void dgemv_ + ( + const f77_char* transa, + const f77_int* m, + const f77_int* n, + const double* alpha, + const double* a, const f77_int* lda, + const double* x, const f77_int* incx, + const double* beta, + double* y, const f77_int* incy + ) +{ + trans_t blis_transa; + dim_t m0, n0; + dim_t m_y, n_x; + double* x0; + double* y0; + inc_t incx0; + inc_t incy0; + inc_t rs_a, cs_a; + + /* Perform BLAS parameter checking. 
*/ + PASTEBLACHK(gemv) + ( + MKSTR(d), + MKSTR(gemv), + transa, + m, + n, + lda, + incx, + incy + ); + + if (*m == 0 || *n == 0) { + return; + } + + /* Map BLAS chars to their corresponding BLIS enumerated type value. */ + if ( *transa == 'n' || *transa == 'N' ) blis_transa = BLIS_NO_TRANSPOSE; + else if ( *transa == 't' || *transa == 'T' ) blis_transa = BLIS_TRANSPOSE; + else if ( *transa == 'c' || *transa == 'C' ) blis_transa = BLIS_CONJ_TRANSPOSE; + else + { + // See comment for bli_param_map_netlib_to_blis_side() above. + //bli_check_error_code( BLIS_INVALID_TRANS ); + blis_transa = BLIS_NO_TRANSPOSE; + } + + /* Convert/typecast negative values of m and n to zero. */ + if ( *m < 0 ) m0 = ( dim_t )0; + else m0 = ( dim_t )(*m); + + if ( *n < 0 ) n0 = ( dim_t )0; + else n0 = ( dim_t )(*n); + + /* Determine the dimensions of x and y so we can adjust the increments, + if necessary.*/ + if ( bli_does_notrans( blis_transa ) ) { m_y = m0; n_x = n0; } + else { m_y = n0; n_x = m0; } + + /* BLAS handles cases where trans(A) has no columns, and x has no elements, + in a peculiar way. In these situations, BLAS returns without performing + any action, even though most sane interpretations of gemv would have the + the operation reduce to y := beta * y. Here, we catch those cases that + BLAS would normally mishandle and emulate the BLAS exactly so as to + provide "bug-for-bug" compatibility. Note that this extreme level of + compatibility would not be as much of an issue if it weren't for the + fact that some BLAS test suites actually test for these cases. Also, it + should be emphasized that BLIS, if called natively, does NOT exhibit + this quirky behavior; it will scale y by beta, as one would expect. */ + if ( m_y > 0 && n_x == 0 ) + { + /* Finalize BLIS. */ + // bli_finalize_auto(); + return; + } + + /* If the input increments are negative, adjust the pointers so we can + use positive increments instead. */ + if ( *incx < 0 ) + { + x0 = ((double*)x) + (n_x-1)*(-*incx); + incx0 = ( inc_t )(*incx); + } + else + { + x0 = ((double*)x); + incx0 = ( inc_t )(*incx); + } + + if ( *incy < 0 ) + { + y0 = ((double*)y) + (m_y-1)*(-*incy); + incy0 = ( inc_t )(*incy); + } + else + { + y0 = ((double*)y); + incy0 = ( inc_t )(*incy); + } + + /* Set the row and column strides of A. */ + rs_a = 1; + cs_a = *lda; + + /* Call variants based on transpose value. */ + if(bli_does_notrans(blis_transa)) + { + //variant_2 is chosen for column-storage + // and uses axpyf-based implementation + bli_dgemv_unf_var2 + ( + blis_transa, + BLIS_NO_CONJUGATE, + m0, + n0, + (double*)alpha, + (double*)a, rs_a, cs_a, + x0, incx0, + (double*)beta, + y0, incy0, + NULL + ); + } + else + { + //var_1 is chosen for row-storage + //and uses dotxf-based implementation + bli_dgemv_unf_var1 + ( + blis_transa, + BLIS_NO_CONJUGATE, + m0, + n0, + (double*)alpha, + (double*)a, rs_a, cs_a, + x0, incx0, + (double*)beta, + y0, incy0, + NULL + ); + + } + +} + +void sgemv_ + ( + const f77_char* transa, + const f77_int* m, + const f77_int* n, + const float* alpha, + const float* a, const f77_int* lda, + const float* x, const f77_int* incx, + const float* beta, + float* y, const f77_int* incy + ) +{ + trans_t blis_transa; + dim_t m0, n0; + dim_t m_y, n_x; + float* x0; + float* y0; + inc_t incx0; + inc_t incy0; + inc_t rs_a, cs_a; + + /* Perform BLAS parameter checking. 
*/ + PASTEBLACHK(gemv) + ( + MKSTR(s), + MKSTR(gemv), + transa, + m, + n, + lda, + incx, + incy + ); + + if (*m == 0 || *n == 0) { + return; + } + + /* Map BLAS chars to their corresponding BLIS enumerated type value. */ + if ( *transa == 'n' || *transa == 'N' ) blis_transa = BLIS_NO_TRANSPOSE; + else if ( *transa == 't' || *transa == 'T' ) blis_transa = BLIS_TRANSPOSE; + else if ( *transa == 'c' || *transa == 'C' ) blis_transa = BLIS_CONJ_TRANSPOSE; + else + { + // See comment for bli_param_map_netlib_to_blis_side() above. + //bli_check_error_code( BLIS_INVALID_TRANS ); + blis_transa = BLIS_NO_TRANSPOSE; + } + + /* Convert/typecast negative values of m and n to zero. */ + if ( *m < 0 ) m0 = ( dim_t )0; + else m0 = ( dim_t )(*m); + + if ( *n < 0 ) n0 = ( dim_t )0; + else n0 = ( dim_t )(*n); + + /* Determine the dimensions of x and y so we can adjust the increments, + if necessary.*/ + if ( bli_does_notrans( blis_transa ) ) { m_y = m0; n_x = n0; } + else { m_y = n0; n_x = m0; } + + /* BLAS handles cases where trans(A) has no columns, and x has no elements, + in a peculiar way. In these situations, BLAS returns without performing + any action, even though most sane interpretations of gemv would have the + the operation reduce to y := beta * y. Here, we catch those cases that + BLAS would normally mishandle and emulate the BLAS exactly so as to + provide "bug-for-bug" compatibility. Note that this extreme level of + compatibility would not be as much of an issue if it weren't for the + fact that some BLAS test suites actually test for these cases. Also, it + should be emphasized that BLIS, if called natively, does NOT exhibit + this quirky behavior; it will scale y by beta, as one would expect. */ + if ( m_y > 0 && n_x == 0 ) + { + /* Finalize BLIS. */ + // bli_finalize_auto(); + return; + } + + /* If the input increments are negative, adjust the pointers so we can + use positive increments instead. */ + if ( *incx < 0 ) + { + x0 = ((float*)x) + (n_x-1)*(-*incx); + incx0 = ( inc_t )(*incx); + } + else + { + x0 = ((float*)x); + incx0 = ( inc_t )(*incx); + } + + if ( *incy < 0 ) + { + y0 = ((float*)y) + (m_y-1)*(-*incy); + incy0 = ( inc_t )(*incy); + } + else + { + y0 = ((float*)y); + incy0 = ( inc_t )(*incy); + } + + /* Set the row and column strides of A. */ + rs_a = 1; + cs_a = *lda; + + /* Call variants based on transpose value. */ + if(bli_does_notrans(blis_transa)) + { + bli_sgemv_unf_var2 + ( + blis_transa, + BLIS_NO_CONJUGATE, + m0, + n0, + (float*)alpha, + (float*)a, rs_a, cs_a, + x0, incx0, + (float*)beta, + y0, incy0, + NULL + ); + + } + else + { + bli_sgemv_unf_var1 + ( + blis_transa, + BLIS_NO_CONJUGATE, + m0, + n0, + (float*)alpha, + (float*)a, rs_a, cs_a, + x0, incx0, + (float*)beta, + y0, incy0, + NULL + ); + + } +} + + +void cgemv_ + ( + const f77_char* transa, + const f77_int* m, + const f77_int* n, + const scomplex* alpha, + const scomplex* a, const f77_int* lda, + const scomplex* x, const f77_int* incx, + const scomplex* beta, + scomplex* y, const f77_int* incy + ) +{ + + trans_t blis_transa; + dim_t m0, n0; + dim_t m_y, n_x; + scomplex* x0; + scomplex* y0; + inc_t incx0; + inc_t incy0; + inc_t rs_a, cs_a; + + /* Perform BLAS parameter checking. */ + PASTEBLACHK(gemv) + ( + MKSTR(c), + MKSTR(gemv), + transa, + m, + n, + lda, + incx, + incy + ); + + if (*m == 0 || *n == 0) { + return; + } + + /* Map BLAS chars to their corresponding BLIS enumerated type value. 
*/ + if( *transa == 'n' || *transa == 'N' ) blis_transa = BLIS_NO_TRANSPOSE; + else if( *transa == 't' || *transa == 'T' ) blis_transa = BLIS_TRANSPOSE; + else if( * transa == 'c' || *transa == 'C' ) blis_transa = BLIS_CONJ_TRANSPOSE; + else + { + // See comment for bli_param_map_netlib_to_blis_side() above. + // bli_check_error_code( BLIS_INVALID_TRANS ); + blis_transa = BLIS_NO_TRANSPOSE; + } + + /* Convert/typecast negative values of m and n to zero. */ + if( *m < 0 ) m0 = (dim_t)0; + else m0 = (dim_t)(*m); + + if( *n < 0 ) n0 = (dim_t)0; + else n0 = (dim_t)(*n); + + /* Determine the dimensions of x and y so we can adjust the increments, + if necessary.*/ + if( bli_does_notrans( blis_transa ) ) { m_y = m0, n_x = n0; } + else { m_y = n0; n_x = m0; } + + /* BLAS handles cases where trans(A) has no columns, and x has no elements, + in a peculiar way. In these situations, BLAS returns without performing + any action, even though most sane interpretations of gemv would have the + the operation reduce to y := beta * y. Here, we catch those cases that + BLAS would normally mishandle and emulate the BLAS exactly so as to + provide "bug-for-bug" compatibility. Note that this extreme level of + compatibility would not be as much of an issue if it weren't for the + fact that some BLAS test suites actually test for these cases. Also, it + should be emphasized that BLIS, if called natively, does NOT exhibit + this quirky behavior; it will scale y by beta, as one would expect. */ + + if ( m_y > 0 && n_x == 0 ) + { + return; + } + + /* If the input increments are negative, adjust the pointers so we can + use positive increments instead. */ + if( *incx < 0 ) + { + x0 = ((scomplex*)x) + (n_x-1)*(-*incx); + incx0 = ( inc_t )(*incx); + } + else + { + x0 = ((scomplex*)x); + incx0 = (inc_t)(*incx); + } + + if ( *incy < 0 ) + { + y0 = ((scomplex*)y) + (m_y-1)*(-*incy); + incy0 = ( inc_t )(*incy); + } + else + { + y0 = ((scomplex*)y); + incy0 = ( inc_t )(*incy); + } + + /* Set the row and column strides of A. */ + rs_a = 1; + cs_a = *lda; + + if( m_y == 1 ) + { + conj_t conja = bli_extract_conj(blis_transa); + scomplex rho; + bli_cdotv_zen_int5 + ( + conja, + BLIS_NO_CONJUGATE, + n_x, + (scomplex*)a, bli_is_notrans(blis_transa)?cs_a:rs_a, + x0, incx0, + &rho, + NULL + ); + scomplex yval = *y0; + if(!bli_ceq0(*beta)) + { + bli_cscals( *beta, yval ); + } + else + { + bli_csetsc( 0.0, 0.0, &yval); + } + if(!bli_ceq0(*alpha)) + { + bli_caxpys( *alpha, rho, yval); + } + y0->real = yval.real; + y0->imag = yval.imag; + + return; + } + + /* call variants based on transpose value */ + if( bli_does_notrans( blis_transa ) ) + { + bli_cgemv_unf_var2 + ( + blis_transa, + BLIS_NO_CONJUGATE, + m0, + n0, + (scomplex*)alpha, + (scomplex*)a, rs_a, cs_a, + x0, incx0, + (scomplex*)beta, + y0, incy0, + NULL + ); + } + else + { + bli_cgemv_unf_var1 + ( + blis_transa, + BLIS_NO_CONJUGATE, + m0, + n0, + (scomplex*)alpha, + (scomplex*)a, rs_a, cs_a, + x0, incx0, + (scomplex*)beta, + y0, incy0, + NULL + ); + } + +} + + +void zgemv_ + ( + const f77_char* transa, + const f77_int* m, + const f77_int* n, + const dcomplex* alpha, + const dcomplex* a, const f77_int* lda, + const dcomplex* x, const f77_int* incx, + const dcomplex* beta, + dcomplex* y, const f77_int* incy + ) +{ + + trans_t blis_transa; + dim_t m0, n0; + dim_t m_y, n_x; + dcomplex* x0; + dcomplex* y0; + inc_t incx0; + inc_t incy0; + inc_t rs_a, cs_a; + + /* Perform BLAS parameter checking. 
*/ + PASTEBLACHK(gemv) + ( + MKSTR(z), + MKSTR(gemv), + transa, + m, + n, + lda, + incx, + incy + ); + + if (*m == 0 || *n == 0) { + return; + } + + /* Map BLAS chars to their corresponding BLIS enumerated type value. */ + if( *transa == 'n' || *transa == 'N' ) blis_transa = BLIS_NO_TRANSPOSE; + else if( *transa == 't' || *transa == 'T' ) blis_transa = BLIS_TRANSPOSE; + else if( * transa == 'c' || *transa == 'C' ) blis_transa = BLIS_CONJ_TRANSPOSE; + else + { + // See comment for bli_param_map_netlib_to_blis_side() above. + // bli_check_error_code( BLIS_INVALID_TRANS ); + blis_transa = BLIS_NO_TRANSPOSE; + } + + /* Convert/typecast negative values of m and n to zero. */ + if( *m < 0 ) m0 = (dim_t)0; + else m0 = (dim_t)(*m); + + if( *n < 0 ) n0 = (dim_t)0; + else n0 = (dim_t)(*n); + + /* Determine the dimensions of x and y so we can adjust the increments, + if necessary.*/ + if( bli_does_notrans( blis_transa ) ) { m_y = m0, n_x = n0; } + else { m_y = n0; n_x = m0; } + + /* BLAS handles cases where trans(A) has no columns, and x has no elements, + in a peculiar way. In these situations, BLAS returns without performing + any action, even though most sane interpretations of gemv would have the + the operation reduce to y := beta * y. Here, we catch those cases that + BLAS would normally mishandle and emulate the BLAS exactly so as to + provide "bug-for-bug" compatibility. Note that this extreme level of + compatibility would not be as much of an issue if it weren't for the + fact that some BLAS test suites actually test for these cases. Also, it + should be emphasized that BLIS, if called natively, does NOT exhibit + this quirky behavior; it will scale y by beta, as one would expect. */ + + if ( m_y > 0 && n_x == 0 ) + { + return; + } + + /* If the input increments are negative, adjust the pointers so we can + use positive increments instead. */ + if( *incx < 0 ) + { + x0 = ((dcomplex*)x) + (n_x-1)*(-*incx); + incx0 = ( inc_t )(*incx); + } + else + { + x0 = ((dcomplex*)x); + incx0 = (inc_t)(*incx); + } + + if ( *incy < 0 ) + { + y0 = ((dcomplex*)y) + (m_y-1)*(-*incy); + incy0 = ( inc_t )(*incy); + } + else + { + y0 = ((dcomplex*)y); + incy0 = ( inc_t )(*incy); + } + + /* Set the row and column strides of A. 
*/ + rs_a = 1; + cs_a = *lda; + + if( m_y == 1 ) + { + conj_t conja = bli_extract_conj(blis_transa); + dcomplex rho; + + bli_zdotv_zen_int5 + ( + conja, + BLIS_NO_CONJUGATE, + n_x, + (dcomplex*)a, bli_is_notrans(blis_transa)?cs_a:rs_a, + x0, incx0, + &rho, + NULL + ); + + dcomplex yval = *y0; + if(!bli_zeq0(*beta)) + { + bli_zscals( *beta, yval ); + } + else + { + bli_zsetsc( 0.0, 0.0, &yval); + } + if(!bli_zeq0(*alpha)) + { + bli_zaxpys( *alpha, rho, yval); + } + y0->real = yval.real; + y0->imag = yval.imag; + + return; + } + + /* call variants based on transpose value */ + if( bli_does_notrans( blis_transa ) ) + { + bli_zgemv_unf_var2 + ( + blis_transa, + BLIS_NO_CONJUGATE, + m0, + n0, + (dcomplex*)alpha, + (dcomplex*)a, rs_a, cs_a, + x0, incx0, + (dcomplex*)beta, + y0, incy0, + NULL + ); + } + else + { + bli_zgemv_unf_var1 + ( + blis_transa, + BLIS_NO_CONJUGATE, + m0, + n0, + (dcomplex*)alpha, + (dcomplex*)a, rs_a, cs_a, + x0, incx0, + (dcomplex*)beta, + y0, incy0, + NULL + ); + } + +} + + +#else INSERT_GENTFUNC_BLAS( gemv, gemv ) #endif - +#endif \ No newline at end of file diff --git a/frame/include/bli_gentfunc_macro_defs.h b/frame/include/bli_gentfunc_macro_defs.h index 011ebcdfb..e3e62e769 100644 --- a/frame/include/bli_gentfunc_macro_defs.h +++ b/frame/include/bli_gentfunc_macro_defs.h @@ -55,6 +55,19 @@ GENTFUNC( double, d, blasname, blisname ) \ GENTFUNC( scomplex, c, blasname, blisname ) \ GENTFUNC( dcomplex, z, blasname, blisname ) +<<<<<<< HEAD +======= +#define INSERT_GENTFUNC_BLAS_SC( blasname, blisname ) \ +\ +GENTFUNC( float, s, blasname, blisname ) \ +GENTFUNC( scomplex, c, blasname, blisname ) + + +#define INSERT_GENTFUNC_BLAS_CZ( blasname, blisname ) \ +\ +GENTFUNC( scomplex, c, blasname, blisname ) \ +GENTFUNC( dcomplex, z, blasname, blisname ) +>>>>>>> 2e1a5bc1d... Optimized double complex axpyf kernel for zgemv // -- Basic one-operand macro with real domain only -- diff --git a/kernels/zen/1f/bli_axpyf_zen_int_4.c b/kernels/zen/1f/bli_axpyf_zen_int_4.c new file mode 100644 index 000000000..f5a043db8 --- /dev/null +++ b/kernels/zen/1f/bli_axpyf_zen_int_4.c @@ -0,0 +1,575 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2021, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include "immintrin.h" +#include "blis.h" + +/* Union data structure to access AVX registers + One 256-bit AVX register holds 8 SP elements. */ +typedef union +{ + __m256 v; + float f[8] __attribute__((aligned(64))); +} v8sf_t; + +/* Union data structure to access AVX registers +* One 256-bit AVX register holds 4 DP elements. */ +typedef union +{ + __m256d v; + double d[4] __attribute__((aligned(64))); +} v4df_t; + + +void bli_caxpyf_zen_int_4 + ( + conj_t conja, + conj_t conjx, + dim_t m, + dim_t b_n, + scomplex* restrict alpha, + scomplex* restrict a, inc_t inca, inc_t lda, + scomplex* restrict x, inc_t incx, + scomplex* restrict y, inc_t incy, + cntx_t* restrict cntx + ) +{ + inc_t fuse_fac = 4; + inc_t i; + + __m256 ymm0, ymm1, ymm2, ymm3; + __m256 ymm4, ymm5, ymm6, ymm7; + __m256 ymm8, ymm10; + __m256 ymm12, ymm13; + + float* ap[4]; + float* y0 = (float*)y; + + scomplex chi0; + scomplex chi1; + scomplex chi2; + scomplex chi3; + + + dim_t setPlusOne = 1; + + if ( bli_is_conj(conja) ) + { + setPlusOne = -1; + } + // If either dimension is zero, or if alpha is zero, return early. + if ( bli_zero_dim2( m, b_n ) || bli_ceq0( *alpha ) ) return; + + // If b_n is not equal to the fusing factor, then perform the entire + // operation as a loop over axpyv. + if ( b_n != fuse_fac ) + { +#ifdef BLIS_CONFIG_EPYC + for ( i = 0; i < b_n; ++i ) + { + scomplex* a1 = a + (0 )*inca + (i )*lda; + scomplex* chi1 = x + (i )*incx; + scomplex* y1 = y + (0 )*incy; + scomplex alpha_chi1; + + bli_ccopycjs( conjx, *chi1, alpha_chi1 ); + bli_cscals( *alpha, alpha_chi1 ); + + bli_caxpyv_zen_int5 + ( + conja, + m, + &alpha_chi1, + a1, inca, + y1, incy, + cntx + ); + } + +#else + caxpyv_ker_ft f = bli_cntx_get_l1v_ker_dt( BLIS_SCOMPLEX, BLIS_AXPYV_KER, cntx ); + + for ( i = 0; i < b_n; ++i ) + { + scomplex* a1 = a + (0 )*inca + (i )*lda; + scomplex* chi1 = x + (i )*incx; + scomplex* y1 = y + (0 )*incy; + scomplex alpha_chi1; + + bli_ccopycjs( conjx, *chi1, alpha_chi1 ); + bli_cscals( *alpha, alpha_chi1 ); + + f + ( + conja, + m, + &alpha_chi1, + a1, inca, + y1, incy, + cntx + ); + } + +#endif + return; + } + + + // At this point, we know that b_n is exactly equal to the fusing factor. + if(bli_is_noconj(conjx)) + { + chi0 = *( x + 0*incx ); + chi1 = *( x + 1*incx ); + chi2 = *( x + 2*incx ); + chi3 = *( x + 3*incx ); + } + else + { + scomplex *pchi0 = x + 0*incx ; + scomplex *pchi1 = x + 1*incx ; + scomplex *pchi2 = x + 2*incx ; + scomplex *pchi3 = x + 3*incx ; + + bli_ccopycjs( conjx, *pchi0, chi0 ); + bli_ccopycjs( conjx, *pchi1, chi1 ); + bli_ccopycjs( conjx, *pchi2, chi2 ); + bli_ccopycjs( conjx, *pchi3, chi3 ); + } + + // Scale each chi scalar by alpha. 
+ bli_cscals( *alpha, chi0 ); + bli_cscals( *alpha, chi1 ); + bli_cscals( *alpha, chi2 ); + bli_cscals( *alpha, chi3 ); + + lda *= 2; + incx *= 2; + incy *= 2; + inca *= 2; + + ap[0] = (float*)a; + ap[1] = (float*)a + lda; + ap[2] = ap[1] + lda; + ap[3] = ap[2] + lda; + + if( inca == 2 && incy == 2 ) + { + inc_t n1 = m >> 2;// div by 4 + inc_t n2 = m & 3;// mod by 4 + + ymm12 = _mm256_setzero_ps(); + ymm13 = _mm256_setzero_ps(); + + // broadcast real & imag parts of 4 elements of x + ymm0 = _mm256_broadcast_ss(&chi0.real); // real part of x0 + ymm1 = _mm256_broadcast_ss(&chi0.imag); // imag part of x0 + ymm2 = _mm256_broadcast_ss(&chi1.real); // real part of x1 + ymm3 = _mm256_broadcast_ss(&chi1.imag); // imag part of x1 + ymm4 = _mm256_broadcast_ss(&chi2.real); // real part of x2 + ymm5 = _mm256_broadcast_ss(&chi2.imag); // imag part of x2 + ymm6 = _mm256_broadcast_ss(&chi3.real); // real part of x3 + ymm7 = _mm256_broadcast_ss(&chi3.imag); // imag part of x3 + + for(i = 0; i < n1; i++) + { + //load first two columns of A + ymm8 = _mm256_loadu_ps(ap[0] + 0); + ymm10 = _mm256_loadu_ps(ap[1] + 0); + + ymm12 = _mm256_mul_ps(ymm8, ymm0); + ymm13 = _mm256_mul_ps(ymm8, ymm1); + + ymm12 = _mm256_fmadd_ps(ymm10, ymm2, ymm12); + ymm13 = _mm256_fmadd_ps(ymm10, ymm3, ymm13); + + //load 3rd and 4th columns of A + ymm8 = _mm256_loadu_ps(ap[2] + 0); + ymm10 = _mm256_loadu_ps(ap[3] + 0); + + ymm12 = _mm256_fmadd_ps(ymm8, ymm4, ymm12); + ymm13 = _mm256_fmadd_ps(ymm8, ymm5, ymm13); + + ymm12 = _mm256_fmadd_ps(ymm10, ymm6, ymm12); + ymm13 = _mm256_fmadd_ps(ymm10, ymm7, ymm13); + + //load Y vector + ymm10 = _mm256_loadu_ps(y0 + 0); + + if(bli_is_noconj(conja)) + { + //printf("Inside no conj if\n"); + ymm13 = _mm256_permute_ps(ymm13, 0xB1); + ymm8 = _mm256_addsub_ps(ymm12, ymm13); + } + else + { + ymm12 = _mm256_permute_ps(ymm12, 0xB1); + ymm8 = _mm256_addsub_ps(ymm13, ymm12); + ymm8 = _mm256_permute_ps(ymm8, 0xB1); + } + + ymm12 = _mm256_add_ps(ymm8, ymm10); + + _mm256_storeu_ps((float*)(y0), ymm12); + + y0 += 8; + ap[0] += 8; + ap[1] += 8; + ap[2] += 8; + ap[3] += 8; + } + + // If there are leftover iterations, perform them with scalar code. 
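+		// setPlusOne is -1 when conja indicates conjugation, which flips the
+		// sign of every a.imag term below so that the scalar update computes
+		// y += chi * conj(a) instead of y += chi * a.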
+ + for ( i = 0; (i + 0) < n2 ; ++i ) + { + + scomplex y0c = *(scomplex*)y0; + + const scomplex a0c = *(scomplex*)ap[0]; + const scomplex a1c = *(scomplex*)ap[1]; + const scomplex a2c = *(scomplex*)ap[2]; + const scomplex a3c = *(scomplex*)ap[3]; + + y0c.real += chi0.real * a0c.real - chi0.imag * a0c.imag * setPlusOne; + y0c.real += chi1.real * a1c.real - chi1.imag * a1c.imag * setPlusOne; + y0c.real += chi2.real * a2c.real - chi2.imag * a2c.imag * setPlusOne; + y0c.real += chi3.real * a3c.real - chi3.imag * a3c.imag * setPlusOne; + + y0c.imag += chi0.imag * a0c.real + chi0.real * a0c.imag * setPlusOne; + y0c.imag += chi1.imag * a1c.real + chi1.real * a1c.imag * setPlusOne; + y0c.imag += chi2.imag * a2c.real + chi2.real * a2c.imag * setPlusOne; + y0c.imag += chi3.imag * a3c.real + chi3.real * a3c.imag * setPlusOne; + + *(scomplex*)y0 = y0c; + + ap[0] += 2; + ap[1] += 2; + ap[2] += 2; + ap[3] += 2; + y0 += 2; + } + //PASTEMAC(c,fprintm)(stdout, "Y after A*x in axpyf",m, 1, (scomplex*)y, 1, 1, "%4.1f", ""); + + } + else + { + for (i = 0 ; (i + 0) < m ; ++i ) + { + scomplex y0c = *(scomplex*)y0; + const scomplex a0c = *(scomplex*)ap[0]; + const scomplex a1c = *(scomplex*)ap[1]; + const scomplex a2c = *(scomplex*)ap[2]; + const scomplex a3c = *(scomplex*)ap[3]; + + y0c.real += chi0.real * a0c.real - chi0.imag * a0c.imag * setPlusOne; + y0c.real += chi1.real * a1c.real - chi1.imag * a1c.imag * setPlusOne; + y0c.real += chi2.real * a2c.real - chi2.imag * a2c.imag * setPlusOne; + y0c.real += chi3.real * a3c.real - chi3.imag * a3c.imag * setPlusOne; + + y0c.imag += chi0.imag * a0c.real + chi0.real * a0c.imag * setPlusOne; + y0c.imag += chi1.imag * a1c.real + chi1.real * a1c.imag * setPlusOne; + y0c.imag += chi2.imag * a2c.real + chi2.real * a2c.imag * setPlusOne; + y0c.imag += chi3.imag * a3c.real + chi3.real * a3c.imag * setPlusOne; + + *(scomplex*)y0 = y0c; + + ap[0] += inca; + ap[1] += inca; + ap[2] += inca; + ap[3] += inca; + y0 += incy; + } + } +} + + +void bli_zaxpyf_zen_int_4 + ( + conj_t conja, + conj_t conjx, + dim_t m, + dim_t b_n, + dcomplex* restrict alpha, + dcomplex* restrict a, inc_t inca, inc_t lda, + dcomplex* restrict x, inc_t incx, + dcomplex* restrict y, inc_t incy, + cntx_t* restrict cntx + ) +{ + inc_t fuse_fac = 4; + inc_t i; + + v4df_t ymm0, ymm1, ymm2, ymm3; + v4df_t ymm4, ymm5, ymm6, ymm7; + v4df_t ymm8, ymm10; + v4df_t ymm12, ymm13; + + double* ap[4]; + double* y0 = (double*)y; + + dcomplex chi0; + dcomplex chi1; + dcomplex chi2; + dcomplex chi3; + + dim_t setPlusOne = 1; + + if ( bli_is_conj(conja) ) + { + setPlusOne = -1; + } + // If either dimension is zero, or if alpha is zero, return early. + if ( bli_zero_dim2( m, b_n ) || bli_zeq0( *alpha ) ) return; + + // If b_n is not equal to the fusing factor, then perform the entire + // operation as a loop over axpyv. 
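+	// (The EPYC build calls bli_zaxpyv_zen_int5 directly; other builds query
+	// the axpyv kernel from the context.)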
+ if ( b_n != fuse_fac ) + { +#ifdef BLIS_CONFIG_EPYC + for ( i = 0; i < b_n; ++i ) + { + dcomplex* a1 = a + (0 )*inca + (i )*lda; + dcomplex* chi1 = x + (i )*incx; + dcomplex* y1 = y + (0 )*incy; + dcomplex alpha_chi1; + + bli_zcopycjs( conjx, *chi1, alpha_chi1 ); + bli_zscals( *alpha, alpha_chi1 ); + + bli_zaxpyv_zen_int5 + ( + conja, + m, + &alpha_chi1, + a1, inca, + y1, incy, + cntx + ); + } +#else + zaxpyv_ker_ft f = bli_cntx_get_l1v_ker_dt( BLIS_DCOMPLEX, BLIS_AXPYV_KER, cntx ); + + for ( i = 0; i < b_n; ++i ) + { + dcomplex* a1 = a + (0 )*inca + (i )*lda; + dcomplex* chi1 = x + (i )*incx; + dcomplex* y1 = y + (0 )*incy; + dcomplex alpha_chi1; + + bli_zcopycjs( conjx, *chi1, alpha_chi1 ); + bli_zscals( *alpha, alpha_chi1 ); + + f + ( + conja, + m, + &alpha_chi1, + a1, inca, + y1, incy, + cntx + ); + } + +#endif + return; + } + + + // At this point, we know that b_n is exactly equal to the fusing factor. + if(bli_is_noconj(conjx)) + { + chi0 = *( x + 0*incx ); + chi1 = *( x + 1*incx ); + chi2 = *( x + 2*incx ); + chi3 = *( x + 3*incx ); + } + else + { + dcomplex *pchi0 = x + 0*incx ; + dcomplex *pchi1 = x + 1*incx ; + dcomplex *pchi2 = x + 2*incx ; + dcomplex *pchi3 = x + 3*incx ; + + bli_zcopycjs( conjx, *pchi0, chi0 ); + bli_zcopycjs( conjx, *pchi1, chi1 ); + bli_zcopycjs( conjx, *pchi2, chi2 ); + bli_zcopycjs( conjx, *pchi3, chi3 ); + } + + // Scale each chi scalar by alpha. + bli_zscals( *alpha, chi0 ); + bli_zscals( *alpha, chi1 ); + bli_zscals( *alpha, chi2 ); + bli_zscals( *alpha, chi3 ); + + lda *= 2; + incx *= 2; + incy *= 2; + inca *= 2; + + ap[0] = (double*)a; + ap[1] = (double*)a + lda; + ap[2] = ap[1] + lda; + ap[3] = ap[2] + lda; + + if( inca == 2 && incy == 2 ) + { + inc_t n1 = m >> 1; // Divide by 2 + inc_t n2 = m & 1; // % 2 + + ymm12.v = _mm256_setzero_pd(); + ymm13.v = _mm256_setzero_pd(); + + // broadcast real & imag parts of 4 elements of x + ymm0.v = _mm256_broadcast_sd(&chi0.real); // real part of x0 + ymm1.v = _mm256_broadcast_sd(&chi0.imag); // imag part of x0 + ymm2.v = _mm256_broadcast_sd(&chi1.real); // real part of x1 + ymm3.v = _mm256_broadcast_sd(&chi1.imag); // imag part of x1 + ymm4.v = _mm256_broadcast_sd(&chi2.real); // real part of x2 + ymm5.v = _mm256_broadcast_sd(&chi2.imag); // imag part of x2 + ymm6.v = _mm256_broadcast_sd(&chi3.real); // real part of x3 + ymm7.v = _mm256_broadcast_sd(&chi3.imag); // imag part of x3 + + + for(i = 0; i < n1; i++) + { + //load first two columns of A + ymm8.v = _mm256_loadu_pd(ap[0] + 0); // 2 complex values form a0 + ymm10.v = _mm256_loadu_pd(ap[1] + 0); // 2 complex values form a0 + + ymm12.v = _mm256_mul_pd(ymm8.v, ymm0.v); + ymm13.v = _mm256_mul_pd(ymm8.v, ymm1.v); + + ymm12.v = _mm256_fmadd_pd(ymm10.v, ymm2.v, ymm12.v); + ymm13.v = _mm256_fmadd_pd(ymm10.v, ymm3.v, ymm13.v); + + //load 3rd and 4th columns of A + ymm8.v = _mm256_loadu_pd(ap[2] + 0); + ymm10.v = _mm256_loadu_pd(ap[3] + 0); + + ymm12.v = _mm256_fmadd_pd(ymm8.v, ymm4.v, ymm12.v); + ymm13.v = _mm256_fmadd_pd(ymm8.v, ymm5.v, ymm13.v); + + ymm12.v = _mm256_fmadd_pd(ymm10.v, ymm6.v, ymm12.v); + ymm13.v = _mm256_fmadd_pd(ymm10.v, ymm7.v, ymm13.v); + + //load Y vector + ymm10.v = _mm256_loadu_pd(y0 + 0); + + if(bli_is_noconj(conja)) + { + ymm13.v = _mm256_permute_pd(ymm13.v, 5); + ymm8.v = _mm256_addsub_pd(ymm12.v, ymm13.v); + } + else + { + ymm12.v = _mm256_permute_pd(ymm12.v, 5); + ymm8.v = _mm256_addsub_pd(ymm13.v, ymm12.v); + ymm8.v = _mm256_permute_pd(ymm8.v, 5); + } + + ymm12.v = _mm256_add_pd(ymm8.v, ymm10.v); + + _mm256_storeu_pd((double*)(y0), 
ymm12.v); + + y0 += 4; + ap[0] += 4; + ap[1] += 4; + ap[2] += 4; + ap[3] += 4; + } + + // If there are leftover iterations, perform them with scalar code. + + for ( i = 0; (i + 0) < n2 ; ++i ) + { + dcomplex y0c = *(dcomplex*)y0; + + const dcomplex a0c = *(dcomplex*)ap[0]; + const dcomplex a1c = *(dcomplex*)ap[1]; + const dcomplex a2c = *(dcomplex*)ap[2]; + const dcomplex a3c = *(dcomplex*)ap[3]; + + y0c.real += chi0.real * a0c.real - chi0.imag * a0c.imag * setPlusOne; + y0c.real += chi1.real * a1c.real - chi1.imag * a1c.imag * setPlusOne; + y0c.real += chi2.real * a2c.real - chi2.imag * a2c.imag * setPlusOne; + y0c.real += chi3.real * a3c.real - chi3.imag * a3c.imag * setPlusOne; + + y0c.imag += chi0.imag * a0c.real + chi0.real * a0c.imag * setPlusOne; + y0c.imag += chi1.imag * a1c.real + chi1.real * a1c.imag * setPlusOne; + y0c.imag += chi2.imag * a2c.real + chi2.real * a2c.imag * setPlusOne; + y0c.imag += chi3.imag * a3c.real + chi3.real * a3c.imag * setPlusOne; + + *(dcomplex*)y0 = y0c; + + ap[0] += 2; + ap[1] += 2; + ap[2] += 2; + ap[3] += 2; + y0 += 2; + } + //PASTEMAC(c,fprintm)(stdout, "Y after A*x in axpyf",m, 1, (scomplex*)y, 1, 1, "%4.1f", ""); + + } + else + { + for (i = 0 ; (i + 0) < m ; ++i ) + { + dcomplex y0c = *(dcomplex*)y0; + const dcomplex a0c = *(dcomplex*)ap[0]; + const dcomplex a1c = *(dcomplex*)ap[1]; + const dcomplex a2c = *(dcomplex*)ap[2]; + const dcomplex a3c = *(dcomplex*)ap[3]; + + y0c.real += chi0.real * a0c.real - chi0.imag * a0c.imag * setPlusOne; + y0c.real += chi1.real * a1c.real - chi1.imag * a1c.imag * setPlusOne; + y0c.real += chi2.real * a2c.real - chi2.imag * a2c.imag * setPlusOne; + y0c.real += chi3.real * a3c.real - chi3.imag * a3c.imag * setPlusOne; + + y0c.imag += chi0.imag * a0c.real + chi0.real * a0c.imag * setPlusOne; + y0c.imag += chi1.imag * a1c.real + chi1.real * a1c.imag * setPlusOne; + y0c.imag += chi2.imag * a2c.real + chi2.real * a2c.imag * setPlusOne; + y0c.imag += chi3.imag * a3c.real + chi3.real * a3c.imag * setPlusOne; + + *(dcomplex*)y0 = y0c; + + ap[0] += inca; + ap[1] += inca; + ap[2] += inca; + ap[3] += inca; + y0 += incy; + } + } +} diff --git a/kernels/zen/bli_kernels_zen.h b/kernels/zen/bli_kernels_zen.h index 161bcef1a..57933c1b1 100644 --- a/kernels/zen/bli_kernels_zen.h +++ b/kernels/zen/bli_kernels_zen.h @@ -92,6 +92,12 @@ SETV_KER_PROT(double, d, setv_zen_int) AXPYF_KER_PROT( float, s, axpyf_zen_int_8 ) AXPYF_KER_PROT( double, d, axpyf_zen_int_8 ) +AXPYF_KER_PROT( float, s, axpyf_zen_int_5 ) +AXPYF_KER_PROT( double, d, axpyf_zen_int_5 ) +AXPYF_KER_PROT( scomplex, c, axpyf_zen_int_5 ) +AXPYF_KER_PROT( dcomplex, z, axpyf_zen_int_5 ) +AXPYF_KER_PROT( dcomplex, z, axpyf_zen_int_4 ) + // dotxf (intrinsics) DOTXF_KER_PROT( float, s, dotxf_zen_int_8 ) DOTXF_KER_PROT( double, d, dotxf_zen_int_8 ) @@ -115,7 +121,7 @@ GEMMSUP_KER_PROT( float, s, gemmsup_rv_zen_asm_2x8 ) GEMMSUP_KER_PROT( float, s, gemmsup_rv_zen_asm_1x8 ) GEMMSUP_KER_PROT( float, s, gemmsup_rv_zen_asm_6x4 ) -GEMMSUP_KER_PROT( float, s, gemmsup_rv_zen_asm_5x4 ) +GEMMSUP_KER_PROT( float, s, gemmsup_rv_zen_asm_5x4 ) GEMMSUP_KER_PROT( float, s, gemmsup_rv_zen_asm_4x4 ) GEMMSUP_KER_PROT( float, s, gemmsup_rv_zen_asm_3x4 ) GEMMSUP_KER_PROT( float, s, gemmsup_rv_zen_asm_2x4 ) diff --git a/kernels/zen2/1f/bli_axpyf_zen_int_5.c b/kernels/zen2/1f/bli_axpyf_zen_int_5.c index 5a919b622..1b595a652 100644 --- a/kernels/zen2/1f/bli_axpyf_zen_int_5.c +++ b/kernels/zen2/1f/bli_axpyf_zen_int_5.c @@ -264,7 +264,7 @@ void bli_saxpyf_zen_int_5 a3 += n_elem_per_reg; a4 += 
n_elem_per_reg; } - + // If there are leftover iterations, perform them with scalar code. for ( ; (i + 0) < m ; ++i ) { @@ -316,7 +316,7 @@ void bli_saxpyf_zen_int_5 a1 += inca; a2 += inca; a3 += inca; - a4 += inca; + a4 += inca; y0 += incy; } @@ -498,8 +498,8 @@ void bli_daxpyf_zen_int_5 // Store the output. - _mm256_storeu_pd( (y0 + 0*n_elem_per_reg), y0v.v ); - _mm256_storeu_pd( (y0 + 1*n_elem_per_reg), y1v.v ); + _mm256_storeu_pd( (double *)(y0 + 0*n_elem_per_reg), y0v.v ); + _mm256_storeu_pd( (double *)(y0 + 1*n_elem_per_reg), y1v.v ); y0 += n_iter_unroll * n_elem_per_reg; a0 += n_iter_unroll * n_elem_per_reg; @@ -538,7 +538,7 @@ void bli_daxpyf_zen_int_5 a3 += n_elem_per_reg; a4 += n_elem_per_reg; } - + // If there are leftover iterations, perform them with scalar code. for ( ; (i + 0) < m ; ++i ) { @@ -590,7 +590,822 @@ void bli_daxpyf_zen_int_5 a1 += inca; a2 += inca; a3 += inca; - a4 += inca; + a4 += inca; + y0 += incy; + } + + } +} + + +// ----------------------------------------------------------------------------- + +void bli_caxpyf_zen_int_5 + ( + conj_t conja, + conj_t conjx, + dim_t m, + dim_t b_n, + scomplex* restrict alpha, + scomplex* restrict a, inc_t inca, inc_t lda, + scomplex* restrict x, inc_t incx, + scomplex* restrict y, inc_t incy, + cntx_t* restrict cntx + ) +{ + const dim_t fuse_fac = 5; + + const dim_t n_elem_per_reg = 4; + + dim_t i = 0; + dim_t setPlusOne = 1; + + v8sf_t chi0v, chi1v, chi2v, chi3v, chi4v; + v8sf_t chi5v, chi6v, chi7v, chi8v, chi9v; + + v8sf_t a00v, a01v, a02v, a03v, a04v; + v8sf_t a05v, a06v, a07v, a08v, a09v; +#if 0 + v8sf_t a10v, a11v, a12v, a13v, a14v; + v8sf_t a15v, a16v, a17v, a18v, a19v; + v8sf_t y1v; +#endif + v8sf_t y0v; + v8sf_t setMinus, setPlus; + + scomplex* restrict a0; + scomplex* restrict a1; + scomplex* restrict a2; + scomplex* restrict a3; + scomplex* restrict a4; + + scomplex* restrict y0; + + scomplex chi0; + scomplex chi1; + scomplex chi2; + scomplex chi3; + scomplex chi4; + + if ( bli_is_conj(conja) ){ + setPlusOne = -1; + } + + // If either dimension is zero, or if alpha is zero, return early. + if ( bli_zero_dim2( m, b_n ) || bli_ceq0( *alpha ) ) return; + + // If b_n is not equal to the fusing factor, then perform the entire + // operation as a loop over axpyv. + if ( b_n != fuse_fac ) + { +#ifdef BLIS_CONFIG_EPYC + for ( i = 0; i < b_n; ++i ) + { + scomplex* a1 = a + (0 )*inca + (i )*lda; + scomplex* chi1 = x + (i )*incx; + scomplex* y1 = y + (0 )*incy; + scomplex alpha_chi1; + + bli_ccopycjs( conjx, *chi1, alpha_chi1 ); + bli_cscals( *alpha, alpha_chi1 ); + + bli_caxpyv_zen_int5 + ( + conja, + m, + &alpha_chi1, + a1, inca, + y1, incy, + cntx + ); + } + +#else + caxpyv_ker_ft f = bli_cntx_get_l1v_ker_dt( BLIS_SCOMPLEX, BLIS_AXPYV_KER, cntx ); + + for ( i = 0; i < b_n; ++i ) + { + scomplex* a1 = a + (0 )*inca + (i )*lda; + scomplex* chi1 = x + (i )*incx; + scomplex* y1 = y + (0 )*incy; + scomplex alpha_chi1; + + bli_ccopycjs( conjx, *chi1, alpha_chi1 ); + bli_cscals( *alpha, alpha_chi1 ); + + f + ( + conja, + m, + &alpha_chi1, + a1, inca, + y1, incy, + cntx + ); + } + +#endif + return; + } + + + // At this point, we know that b_n is exactly equal to the fusing factor. 
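+	// Set up pointers to the five columns of A, gather the five elements of x
+	// (conjugating them if conjx requests it), and scale each by alpha.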
+ + a0 = a + 0*lda; + a1 = a + 1*lda; + a2 = a + 2*lda; + a3 = a + 3*lda; + a4 = a + 4*lda; + y0 = y; + + chi0 = *( x + 0*incx ); + chi1 = *( x + 1*incx ); + chi2 = *( x + 2*incx ); + chi3 = *( x + 3*incx ); + chi4 = *( x + 4*incx ); + + scomplex *pchi0 = x + 0*incx ; + scomplex *pchi1 = x + 1*incx ; + scomplex *pchi2 = x + 2*incx ; + scomplex *pchi3 = x + 3*incx ; + scomplex *pchi4 = x + 4*incx ; + + bli_ccopycjs( conjx, *pchi0, chi0 ); + bli_ccopycjs( conjx, *pchi1, chi1 ); + bli_ccopycjs( conjx, *pchi2, chi2 ); + bli_ccopycjs( conjx, *pchi3, chi3 ); + bli_ccopycjs( conjx, *pchi4, chi4 ); + + // Scale each chi scalar by alpha. + bli_cscals( *alpha, chi0 ); + bli_cscals( *alpha, chi1 ); + bli_cscals( *alpha, chi2 ); + bli_cscals( *alpha, chi3 ); + bli_cscals( *alpha, chi4 ); + + // Broadcast the (alpha*chi?) scalars to all elements of vector registers. + chi0v.v = _mm256_broadcast_ss( &chi0.real ); + chi1v.v = _mm256_broadcast_ss( &chi1.real ); + chi2v.v = _mm256_broadcast_ss( &chi2.real ); + chi3v.v = _mm256_broadcast_ss( &chi3.real ); + chi4v.v = _mm256_broadcast_ss( &chi4.real ); + + chi5v.v = _mm256_broadcast_ss( &chi0.imag ); + chi6v.v = _mm256_broadcast_ss( &chi1.imag ); + chi7v.v = _mm256_broadcast_ss( &chi2.imag ); + chi8v.v = _mm256_broadcast_ss( &chi3.imag ); + chi9v.v = _mm256_broadcast_ss( &chi4.imag ); + + // If there are vectorized iterations, perform them with vector + // instructions. + if ( inca == 1 && incy == 1 ) + { + setMinus.v = _mm256_set_ps( -1, 1, -1, 1, -1, 1, -1, 1 ); + + setPlus.v = _mm256_set1_ps( 1 ); + if ( bli_is_conj(conja) ){ + setPlus.v = _mm256_set_ps( -1, 1, -1, 1, -1, 1, -1, 1 ); + } + + /* + y := y + alpha * conja(A) * conjx(x) + + nn + (ar + ai) (xr + xi) + ar * xr - ai * xi + ar * xi + ai * xr + + cc : (ar - ai) (xr - xi) + ar * xr - ai * xi + -(ar * xi + ai * xr) + + nc : (ar + ai) (xr - xi) + ar * xr + ai * xi + -(ar * xi - ai * xr) + + cn : (ar - ai) (xr + xi) + ar * xr + ai * xi + ar * xi - ai * xr + + */ + + i = 0; +#if 0 //Low performance + for( i = 0; (i + 7) < m; i += 8 ) + { + // Load the input values. 
+ y0v.v = _mm256_loadu_ps( (float*) (y0 + 0*n_elem_per_reg )); + y1v.v = _mm256_loadu_ps( (float*) (y0 + 1*n_elem_per_reg )); + + a00v.v = _mm256_loadu_ps( (float*) (a0 + 0*n_elem_per_reg )); + a10v.v = _mm256_loadu_ps( (float*) (a0 + 1*n_elem_per_reg )); + + a01v.v = _mm256_loadu_ps( (float*) (a1 + 0*n_elem_per_reg )); + a11v.v = _mm256_loadu_ps( (float*) (a1 + 1*n_elem_per_reg )); + + a02v.v = _mm256_loadu_ps( (float*) (a2 + 0*n_elem_per_reg )); + a12v.v = _mm256_loadu_ps( (float*) (a2 + 1*n_elem_per_reg )); + + a03v.v = _mm256_loadu_ps( (float*) (a3 + 0*n_elem_per_reg )); + a13v.v = _mm256_loadu_ps( (float*) (a3 + 1*n_elem_per_reg )); + + a04v.v = _mm256_loadu_ps( (float*) (a4 + 0*n_elem_per_reg )); + a14v.v = _mm256_loadu_ps( (float*) (a4 + 1*n_elem_per_reg )); + + a00v.v = _mm256_mul_ps( a00v.v, setPlus.v ); + a01v.v = _mm256_mul_ps( a01v.v, setPlus.v ); + a02v.v = _mm256_mul_ps( a02v.v, setPlus.v ); + a03v.v = _mm256_mul_ps( a03v.v, setPlus.v ); + a04v.v = _mm256_mul_ps( a04v.v, setPlus.v ); + + a05v.v = _mm256_mul_ps( a00v.v, setMinus.v ); + a06v.v = _mm256_mul_ps( a01v.v, setMinus.v ); + a07v.v = _mm256_mul_ps( a02v.v, setMinus.v ); + a08v.v = _mm256_mul_ps( a03v.v, setMinus.v ); + a09v.v = _mm256_mul_ps( a04v.v, setMinus.v ); + + a05v.v = _mm256_permute_ps( a05v.v, 0xB1 ); + a06v.v = _mm256_permute_ps( a06v.v, 0xB1 ); + a07v.v = _mm256_permute_ps( a07v.v, 0xB1 ); + a08v.v = _mm256_permute_ps( a08v.v, 0xB1 ); + a09v.v = _mm256_permute_ps( a09v.v, 0xB1 ); + + a10v.v = _mm256_mul_ps( a10v.v, setPlus.v ); + a11v.v = _mm256_mul_ps( a11v.v, setPlus.v ); + a12v.v = _mm256_mul_ps( a12v.v, setPlus.v ); + a13v.v = _mm256_mul_ps( a13v.v, setPlus.v ); + a14v.v = _mm256_mul_ps( a14v.v, setPlus.v ); + + a15v.v = _mm256_mul_ps( a10v.v, setMinus.v ); + a16v.v = _mm256_mul_ps( a11v.v, setMinus.v ); + a17v.v = _mm256_mul_ps( a12v.v, setMinus.v ); + a18v.v = _mm256_mul_ps( a13v.v, setMinus.v ); + a19v.v = _mm256_mul_ps( a14v.v, setMinus.v ); + + a15v.v = _mm256_permute_ps( a15v.v, 0xB1 ); + a16v.v = _mm256_permute_ps( a16v.v, 0xB1 ); + a17v.v = _mm256_permute_ps( a17v.v, 0xB1 ); + a18v.v = _mm256_permute_ps( a18v.v, 0xB1 ); + a19v.v = _mm256_permute_ps( a19v.v, 0xB1 ); + + // perform : y += alpha * x; + y0v.v = _mm256_fmadd_ps( a00v.v, chi0v.v, y0v.v ); + y0v.v = _mm256_fmadd_ps( a01v.v, chi1v.v, y0v.v ); + y0v.v = _mm256_fmadd_ps( a02v.v, chi2v.v, y0v.v ); + y0v.v = _mm256_fmadd_ps( a03v.v, chi3v.v, y0v.v ); + y0v.v = _mm256_fmadd_ps( a04v.v, chi4v.v, y0v.v ); + + y0v.v = _mm256_fmadd_ps( a05v.v, chi5v.v, y0v.v ); + y0v.v = _mm256_fmadd_ps( a06v.v, chi6v.v, y0v.v ); + y0v.v = _mm256_fmadd_ps( a07v.v, chi7v.v, y0v.v ); + y0v.v = _mm256_fmadd_ps( a08v.v, chi8v.v, y0v.v ); + y0v.v = _mm256_fmadd_ps( a09v.v, chi9v.v, y0v.v ); + + // For next 4 elements perform : y += alpha * x; + y1v.v = _mm256_fmadd_ps( a10v.v, chi0v.v, y1v.v ); + y1v.v = _mm256_fmadd_ps( a11v.v, chi1v.v, y1v.v ); + y1v.v = _mm256_fmadd_ps( a12v.v, chi2v.v, y1v.v ); + y1v.v = _mm256_fmadd_ps( a13v.v, chi3v.v, y1v.v ); + y1v.v = _mm256_fmadd_ps( a14v.v, chi4v.v, y1v.v ); + + y1v.v = _mm256_fmadd_ps( a15v.v, chi5v.v, y1v.v ); + y1v.v = _mm256_fmadd_ps( a16v.v, chi6v.v, y1v.v ); + y1v.v = _mm256_fmadd_ps( a17v.v, chi7v.v, y1v.v ); + y1v.v = _mm256_fmadd_ps( a18v.v, chi8v.v, y1v.v ); + y1v.v = _mm256_fmadd_ps( a19v.v, chi9v.v, y1v.v ); + + // Store the output. 
+ _mm256_storeu_ps( (float *)(y0 + 0*n_elem_per_reg), y0v.v ); + _mm256_storeu_ps( (float *)(y0 + 1*n_elem_per_reg), y1v.v ); + + y0 += n_elem_per_reg * n_iter_unroll; + a0 += n_elem_per_reg * n_iter_unroll; + a1 += n_elem_per_reg * n_iter_unroll; + a2 += n_elem_per_reg * n_iter_unroll; + a3 += n_elem_per_reg * n_iter_unroll; + a4 += n_elem_per_reg * n_iter_unroll; + } +#endif + for( ; (i + 3) < m; i += 4 ) + { + // Load the input values. + y0v.v = _mm256_loadu_ps( (float*) (y0 + 0*n_elem_per_reg )); + + a00v.v = _mm256_loadu_ps( (float*) (a0 + 0*n_elem_per_reg )); + a01v.v = _mm256_loadu_ps( (float*) (a1 + 0*n_elem_per_reg )); + a02v.v = _mm256_loadu_ps( (float*) (a2 + 0*n_elem_per_reg )); + a03v.v = _mm256_loadu_ps( (float*) (a3 + 0*n_elem_per_reg )); + a04v.v = _mm256_loadu_ps( (float*) (a4 + 0*n_elem_per_reg )); + + a00v.v = _mm256_mul_ps( a00v.v, setPlus.v ); + a01v.v = _mm256_mul_ps( a01v.v, setPlus.v ); + a02v.v = _mm256_mul_ps( a02v.v, setPlus.v ); + a03v.v = _mm256_mul_ps( a03v.v, setPlus.v ); + a04v.v = _mm256_mul_ps( a04v.v, setPlus.v ); + + a05v.v = _mm256_mul_ps( a00v.v, setMinus.v ); + a06v.v = _mm256_mul_ps( a01v.v, setMinus.v ); + a07v.v = _mm256_mul_ps( a02v.v, setMinus.v ); + a08v.v = _mm256_mul_ps( a03v.v, setMinus.v ); + a09v.v = _mm256_mul_ps( a04v.v, setMinus.v ); + + a05v.v = _mm256_permute_ps( a05v.v, 0xB1 ); + a06v.v = _mm256_permute_ps( a06v.v, 0xB1 ); + a07v.v = _mm256_permute_ps( a07v.v, 0xB1 ); + a08v.v = _mm256_permute_ps( a08v.v, 0xB1 ); + a09v.v = _mm256_permute_ps( a09v.v, 0xB1 ); + + // perform : y += alpha * x; + y0v.v = _mm256_fmadd_ps( a00v.v, chi0v.v, y0v.v ); + y0v.v = _mm256_fmadd_ps( a01v.v, chi1v.v, y0v.v ); + y0v.v = _mm256_fmadd_ps( a02v.v, chi2v.v, y0v.v ); + y0v.v = _mm256_fmadd_ps( a03v.v, chi3v.v, y0v.v ); + y0v.v = _mm256_fmadd_ps( a04v.v, chi4v.v, y0v.v ); + + y0v.v = _mm256_fmadd_ps( a05v.v, chi5v.v, y0v.v ); + y0v.v = _mm256_fmadd_ps( a06v.v, chi6v.v, y0v.v ); + y0v.v = _mm256_fmadd_ps( a07v.v, chi7v.v, y0v.v ); + y0v.v = _mm256_fmadd_ps( a08v.v, chi8v.v, y0v.v ); + y0v.v = _mm256_fmadd_ps( a09v.v, chi9v.v, y0v.v ); + + // Store the output. + _mm256_storeu_ps( (float *)(y0 + 0*n_elem_per_reg), y0v.v ); + + y0 += n_elem_per_reg ; + a0 += n_elem_per_reg ; + a1 += n_elem_per_reg ; + a2 += n_elem_per_reg ; + a3 += n_elem_per_reg ; + a4 += n_elem_per_reg ; + } + + // If there are leftover iterations, perform them with scalar code. 
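+		// (setPlusOne carries the conj(A) sign into the a.imag terms below.)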
+ for ( ; (i + 0) < m ; ++i ) + { + scomplex y0c = *y0; + + const scomplex a0c = *a0; + const scomplex a1c = *a1; + const scomplex a2c = *a2; + const scomplex a3c = *a3; + const scomplex a4c = *a4; + + y0c.real += chi0.real * a0c.real - chi0.imag * a0c.imag * setPlusOne; + y0c.real += chi1.real * a1c.real - chi1.imag * a1c.imag * setPlusOne; + y0c.real += chi2.real * a2c.real - chi2.imag * a2c.imag * setPlusOne; + y0c.real += chi3.real * a3c.real - chi3.imag * a3c.imag * setPlusOne; + y0c.real += chi4.real * a4c.real - chi4.imag * a4c.imag * setPlusOne; + + y0c.imag += chi0.imag * a0c.real + chi0.real * a0c.imag * setPlusOne; + y0c.imag += chi1.imag * a1c.real + chi1.real * a1c.imag * setPlusOne; + y0c.imag += chi2.imag * a2c.real + chi2.real * a2c.imag * setPlusOne; + y0c.imag += chi3.imag * a3c.real + chi3.real * a3c.imag * setPlusOne; + y0c.imag += chi4.imag * a4c.real + chi4.real * a4c.imag * setPlusOne; + + *y0 = y0c; + + a0 += 1; + a1 += 1; + a2 += 1; + a3 += 1; + a4 += 1; + y0 += 1; + } + + } + else + { + for ( ; (i + 0) < m ; ++i ) + { + scomplex y0c = *y0; + const scomplex a0c = *a0; + const scomplex a1c = *a1; + const scomplex a2c = *a2; + const scomplex a3c = *a3; + const scomplex a4c = *a4; + + y0c.real += chi0.real * a0c.real - chi0.imag * a0c.imag * setPlusOne; + y0c.real += chi1.real * a1c.real - chi1.imag * a1c.imag * setPlusOne; + y0c.real += chi2.real * a2c.real - chi2.imag * a2c.imag * setPlusOne; + y0c.real += chi3.real * a3c.real - chi3.imag * a3c.imag * setPlusOne; + y0c.real += chi4.real * a4c.real - chi4.imag * a4c.imag * setPlusOne; + + y0c.imag += chi0.imag * a0c.real + chi0.real * a0c.imag * setPlusOne; + y0c.imag += chi1.imag * a1c.real + chi1.real * a1c.imag * setPlusOne; + y0c.imag += chi2.imag * a2c.real + chi2.real * a2c.imag * setPlusOne; + y0c.imag += chi3.imag * a3c.real + chi3.real * a3c.imag * setPlusOne; + y0c.imag += chi4.imag * a4c.real + chi4.real * a4c.imag * setPlusOne; + + *y0 = y0c; + + a0 += inca; + a1 += inca; + a2 += inca; + a3 += inca; + a4 += inca; + y0 += incy; + } + } +} + + +// ----------------------------------------------------------------------------- + +void bli_zaxpyf_zen_int_5 + ( + conj_t conja, + conj_t conjx, + dim_t m, + dim_t b_n, + dcomplex* restrict alpha, + dcomplex* restrict a, inc_t inca, inc_t lda, + dcomplex* restrict x, inc_t incx, + dcomplex* restrict y, inc_t incy, + cntx_t* restrict cntx + ) +{ + const dim_t fuse_fac = 5; + + const dim_t n_elem_per_reg = 2; + const dim_t n_iter_unroll = 2; + + dim_t i = 0; + dim_t setPlusOne = 1; + + v4df_t chi0v, chi1v, chi2v, chi3v, chi4v; + v4df_t chi5v, chi6v, chi7v, chi8v, chi9v; + + v4df_t a00v, a01v, a02v, a03v, a04v; + v4df_t a05v, a06v, a07v, a08v, a09v; + + v4df_t a10v, a11v, a12v, a13v, a14v; + v4df_t a15v, a16v, a17v, a18v, a19v; + + v4df_t y0v, y1v; + v4df_t setMinus, setPlus; + + dcomplex chi0, chi1, chi2, chi3, chi4; + dcomplex* restrict a0; + dcomplex* restrict a1; + dcomplex* restrict a2; + dcomplex* restrict a3; + dcomplex* restrict a4; + + dcomplex* restrict y0; + + + if ( bli_is_conj(conja) ){ + setPlusOne = -1; + } + + // If either dimension is zero, or if alpha is zero, return early. + if ( bli_zero_dim2( m, b_n ) || bli_zeq0( *alpha ) ) return; + + // If b_n is not equal to the fusing factor, then perform the entire + // operation as a loop over axpyv. 
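+	// (Same fallback as the scomplex kernel above: the EPYC build calls
+	// bli_zaxpyv_zen_int5 directly, other builds use the context's axpyv kernel.)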
+ if ( b_n != fuse_fac ) + { +#ifdef BLIS_CONFIG_EPYC + for ( i = 0; i < b_n; ++i ) + { + dcomplex* a1 = a + (0 )*inca + (i )*lda; + dcomplex* chi1 = x + (i )*incx; + dcomplex* y1 = y + (0 )*incy; + dcomplex alpha_chi1; + + bli_zcopycjs( conjx, *chi1, alpha_chi1 ); + bli_zscals( *alpha, alpha_chi1 ); + + bli_zaxpyv_zen_int5 + ( + conja, + m, + &alpha_chi1, + a1, inca, + y1, incy, + cntx + ); + } + +#else + zaxpyv_ker_ft f = bli_cntx_get_l1v_ker_dt( BLIS_DCOMPLEX, BLIS_AXPYV_KER, cntx ); + + for ( i = 0; i < b_n; ++i ) + { + dcomplex* a1 = a + (0 )*inca + (i )*lda; + dcomplex* chi1 = x + (i )*incx; + dcomplex* y1 = y + (0 )*incy; + dcomplex alpha_chi1; + + bli_zcopycjs( conjx, *chi1, alpha_chi1 ); + bli_zscals( *alpha, alpha_chi1 ); + + f + ( + conja, + m, + &alpha_chi1, + a1, inca, + y1, incy, + cntx + ); + } + +#endif + return; + } + + + // At this point, we know that b_n is exactly equal to the fusing factor. + + a0 = a + 0*lda; + a1 = a + 1*lda; + a2 = a + 2*lda; + a3 = a + 3*lda; + a4 = a + 4*lda; + y0 = y; + + chi0 = *( x + 0*incx ); + chi1 = *( x + 1*incx ); + chi2 = *( x + 2*incx ); + chi3 = *( x + 3*incx ); + chi4 = *( x + 4*incx ); + + dcomplex *pchi0 = x + 0*incx ; + dcomplex *pchi1 = x + 1*incx ; + dcomplex *pchi2 = x + 2*incx ; + dcomplex *pchi3 = x + 3*incx ; + dcomplex *pchi4 = x + 4*incx ; + + bli_zcopycjs( conjx, *pchi0, chi0 ); + bli_zcopycjs( conjx, *pchi1, chi1 ); + bli_zcopycjs( conjx, *pchi2, chi2 ); + bli_zcopycjs( conjx, *pchi3, chi3 ); + bli_zcopycjs( conjx, *pchi4, chi4 ); + + // Scale each chi scalar by alpha. + bli_zscals( *alpha, chi0 ); + bli_zscals( *alpha, chi1 ); + bli_zscals( *alpha, chi2 ); + bli_zscals( *alpha, chi3 ); + bli_zscals( *alpha, chi4 ); + + // Broadcast the (alpha*chi?) scalars to all elements of vector registers. + chi0v.v = _mm256_broadcast_sd( &chi0.real ); + chi1v.v = _mm256_broadcast_sd( &chi1.real ); + chi2v.v = _mm256_broadcast_sd( &chi2.real ); + chi3v.v = _mm256_broadcast_sd( &chi3.real ); + chi4v.v = _mm256_broadcast_sd( &chi4.real ); + + chi5v.v = _mm256_broadcast_sd( &chi0.imag ); + chi6v.v = _mm256_broadcast_sd( &chi1.imag ); + chi7v.v = _mm256_broadcast_sd( &chi2.imag ); + chi8v.v = _mm256_broadcast_sd( &chi3.imag ); + chi9v.v = _mm256_broadcast_sd( &chi4.imag ); + + // If there are vectorized iterations, perform them with vector + // instructions. + if ( inca == 1 && incy == 1 ) + { + setMinus.v = _mm256_set_pd( -1, 1, -1, 1 ); + + setPlus.v = _mm256_set1_pd( 1 ); + if ( bli_is_conj(conja) ){ + setPlus.v = _mm256_set_pd( -1, 1, -1, 1 ); + } + + /* + y := y + alpha * conja(A) * conjx(x) + + nn + (ar + ai) (xr + xi) + ar * xr - ai * xi + ar * xi + ai * xr + + cc : (ar - ai) (xr - xi) + ar * xr - ai * xi + -(ar * xi + ai * xr) + + nc : (ar + ai) (xr - xi) + ar * xr + ai * xi + -(ar * xi - ai * xr) + + cn : (ar - ai) (xr + xi) + ar * xr + ai * xi + ar * xi - ai * xr + + */ + + for( i = 0; (i + 3) < m; i += 4 ) + { + // Load the input values. 
+ y0v.v = _mm256_loadu_pd( (double*) (y0 + 0*n_elem_per_reg )); + y1v.v = _mm256_loadu_pd( (double*) (y0 + 1*n_elem_per_reg )); + + a00v.v = _mm256_loadu_pd( (double*) (a0 + 0*n_elem_per_reg )); + a10v.v = _mm256_loadu_pd( (double*) (a0 + 1*n_elem_per_reg )); + + a01v.v = _mm256_loadu_pd( (double*) (a1 + 0*n_elem_per_reg )); + a11v.v = _mm256_loadu_pd( (double*) (a1 + 1*n_elem_per_reg )); + + a02v.v = _mm256_loadu_pd( (double*) (a2 + 0*n_elem_per_reg )); + a12v.v = _mm256_loadu_pd( (double*) (a2 + 1*n_elem_per_reg )); + + a03v.v = _mm256_loadu_pd( (double*) (a3 + 0*n_elem_per_reg )); + a13v.v = _mm256_loadu_pd( (double*) (a3 + 1*n_elem_per_reg )); + + a04v.v = _mm256_loadu_pd( (double*) (a4 + 0*n_elem_per_reg )); + a14v.v = _mm256_loadu_pd( (double*) (a4 + 1*n_elem_per_reg )); + + a00v.v = _mm256_mul_pd( a00v.v, setPlus.v ); + a01v.v = _mm256_mul_pd( a01v.v, setPlus.v ); + a02v.v = _mm256_mul_pd( a02v.v, setPlus.v ); + a03v.v = _mm256_mul_pd( a03v.v, setPlus.v ); + a04v.v = _mm256_mul_pd( a04v.v, setPlus.v ); + + a05v.v = _mm256_mul_pd( a00v.v, setMinus.v ); + a06v.v = _mm256_mul_pd( a01v.v, setMinus.v ); + a07v.v = _mm256_mul_pd( a02v.v, setMinus.v ); + a08v.v = _mm256_mul_pd( a03v.v, setMinus.v ); + a09v.v = _mm256_mul_pd( a04v.v, setMinus.v ); + + a05v.v = _mm256_permute_pd( a05v.v, 5 ); + a06v.v = _mm256_permute_pd( a06v.v, 5 ); + a07v.v = _mm256_permute_pd( a07v.v, 5 ); + a08v.v = _mm256_permute_pd( a08v.v, 5 ); + a09v.v = _mm256_permute_pd( a09v.v, 5 ); + + a10v.v = _mm256_mul_pd( a10v.v, setPlus.v ); + a11v.v = _mm256_mul_pd( a11v.v, setPlus.v ); + a12v.v = _mm256_mul_pd( a12v.v, setPlus.v ); + a13v.v = _mm256_mul_pd( a13v.v, setPlus.v ); + a14v.v = _mm256_mul_pd( a14v.v, setPlus.v ); + + a15v.v = _mm256_mul_pd( a10v.v, setMinus.v ); + a16v.v = _mm256_mul_pd( a11v.v, setMinus.v ); + a17v.v = _mm256_mul_pd( a12v.v, setMinus.v ); + a18v.v = _mm256_mul_pd( a13v.v, setMinus.v ); + a19v.v = _mm256_mul_pd( a14v.v, setMinus.v ); + + a15v.v = _mm256_permute_pd( a15v.v, 5 ); + a16v.v = _mm256_permute_pd( a16v.v, 5 ); + a17v.v = _mm256_permute_pd( a17v.v, 5 ); + a18v.v = _mm256_permute_pd( a18v.v, 5 ); + a19v.v = _mm256_permute_pd( a19v.v, 5 ); + + // perform : y += alpha * x; + y0v.v = _mm256_fmadd_pd( a00v.v, chi0v.v, y0v.v ); + y0v.v = _mm256_fmadd_pd( a01v.v, chi1v.v, y0v.v ); + y0v.v = _mm256_fmadd_pd( a02v.v, chi2v.v, y0v.v ); + y0v.v = _mm256_fmadd_pd( a03v.v, chi3v.v, y0v.v ); + y0v.v = _mm256_fmadd_pd( a04v.v, chi4v.v, y0v.v ); + + y0v.v = _mm256_fmadd_pd( a05v.v, chi5v.v, y0v.v ); + y0v.v = _mm256_fmadd_pd( a06v.v, chi6v.v, y0v.v ); + y0v.v = _mm256_fmadd_pd( a07v.v, chi7v.v, y0v.v ); + y0v.v = _mm256_fmadd_pd( a08v.v, chi8v.v, y0v.v ); + y0v.v = _mm256_fmadd_pd( a09v.v, chi9v.v, y0v.v ); + + // For next 4 elements perform : y += alpha * x; + y1v.v = _mm256_fmadd_pd( a10v.v, chi0v.v, y1v.v ); + y1v.v = _mm256_fmadd_pd( a11v.v, chi1v.v, y1v.v ); + y1v.v = _mm256_fmadd_pd( a12v.v, chi2v.v, y1v.v ); + y1v.v = _mm256_fmadd_pd( a13v.v, chi3v.v, y1v.v ); + y1v.v = _mm256_fmadd_pd( a14v.v, chi4v.v, y1v.v ); + + y1v.v = _mm256_fmadd_pd( a15v.v, chi5v.v, y1v.v ); + y1v.v = _mm256_fmadd_pd( a16v.v, chi6v.v, y1v.v ); + y1v.v = _mm256_fmadd_pd( a17v.v, chi7v.v, y1v.v ); + y1v.v = _mm256_fmadd_pd( a18v.v, chi8v.v, y1v.v ); + y1v.v = _mm256_fmadd_pd( a19v.v, chi9v.v, y1v.v ); + + // Store the output. 
+ _mm256_storeu_pd( (double*) (y0 + 0*n_elem_per_reg), y0v.v ); + _mm256_storeu_pd( (double*) (y0 + 1*n_elem_per_reg), y1v.v ); + + y0 += n_elem_per_reg * n_iter_unroll; + a0 += n_elem_per_reg * n_iter_unroll; + a1 += n_elem_per_reg * n_iter_unroll; + a2 += n_elem_per_reg * n_iter_unroll; + a3 += n_elem_per_reg * n_iter_unroll; + a4 += n_elem_per_reg * n_iter_unroll; + } + for( ; (i + 1) < m; i += 2 ) + { + // Load the input values. + y0v.v = _mm256_loadu_pd( (double*) (y0 + 0*n_elem_per_reg )); + + a00v.v = _mm256_loadu_pd( (double*)(a0 + 0*n_elem_per_reg) ); + a01v.v = _mm256_loadu_pd( (double*)(a1 + 0*n_elem_per_reg) ); + a02v.v = _mm256_loadu_pd( (double*)(a2 + 0*n_elem_per_reg) ); + a03v.v = _mm256_loadu_pd( (double*)(a3 + 0*n_elem_per_reg) ); + a04v.v = _mm256_loadu_pd( (double*)(a4 + 0*n_elem_per_reg) ); + + a00v.v = _mm256_mul_pd( a00v.v, setPlus.v ); + a01v.v = _mm256_mul_pd( a01v.v, setPlus.v ); + a02v.v = _mm256_mul_pd( a02v.v, setPlus.v ); + a03v.v = _mm256_mul_pd( a03v.v, setPlus.v ); + a04v.v = _mm256_mul_pd( a04v.v, setPlus.v ); + + a05v.v = _mm256_mul_pd( a00v.v, setMinus.v ); + a06v.v = _mm256_mul_pd( a01v.v, setMinus.v ); + a07v.v = _mm256_mul_pd( a02v.v, setMinus.v ); + a08v.v = _mm256_mul_pd( a03v.v, setMinus.v ); + a09v.v = _mm256_mul_pd( a04v.v, setMinus.v ); + + a05v.v = _mm256_permute_pd( a05v.v, 5 ); + a06v.v = _mm256_permute_pd( a06v.v, 5 ); + a07v.v = _mm256_permute_pd( a07v.v, 5 ); + a08v.v = _mm256_permute_pd( a08v.v, 5 ); + a09v.v = _mm256_permute_pd( a09v.v, 5 ); + + // perform : y += alpha * x; + y0v.v = _mm256_fmadd_pd( a00v.v, chi0v.v, y0v.v ); + y0v.v = _mm256_fmadd_pd( a01v.v, chi1v.v, y0v.v ); + y0v.v = _mm256_fmadd_pd( a02v.v, chi2v.v, y0v.v ); + y0v.v = _mm256_fmadd_pd( a03v.v, chi3v.v, y0v.v ); + y0v.v = _mm256_fmadd_pd( a04v.v, chi4v.v, y0v.v ); + + y0v.v = _mm256_fmadd_pd( a05v.v, chi5v.v, y0v.v ); + y0v.v = _mm256_fmadd_pd( a06v.v, chi6v.v, y0v.v ); + y0v.v = _mm256_fmadd_pd( a07v.v, chi7v.v, y0v.v ); + y0v.v = _mm256_fmadd_pd( a08v.v, chi8v.v, y0v.v ); + y0v.v = _mm256_fmadd_pd( a09v.v, chi9v.v, y0v.v ); + + // Store the output. + _mm256_storeu_pd( (double *)(y0 + 0*n_elem_per_reg), y0v.v ); + + y0 += n_elem_per_reg ; + a0 += n_elem_per_reg ; + a1 += n_elem_per_reg ; + a2 += n_elem_per_reg ; + a3 += n_elem_per_reg ; + a4 += n_elem_per_reg ; + } + // If there are leftover iterations, perform them with scalar code. 
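To summarize the cleanup structure in this kernel: the unit-stride path peels four dcomplex per iteration with two YMM registers, then two dcomplex with one register, and finally falls through to the scalar loop below; when inca or incy is not 1, the vector loops are skipped entirely and the scalar form runs over all m elements. A schematic of the tiering (sketch only; the loop-bound form i + k <= m is equivalent to the (i + k - 1) < m tests used above):

/* Schematic of the remainder tiering; the real loops carry the vector and
   scalar bodies shown in this hunk. */
#include <stddef.h>

static void tiered_loop_shape( size_t m )
{
    size_t i = 0;
    for ( ; i + 4 <= m; i += 4 ) { /* two YMM registers: four dcomplex */ }
    for ( ; i + 2 <= m; i += 2 ) { /* one YMM register: two dcomplex   */ }
    for ( ; i     < m; i += 1 ) { /* scalar cleanup: at most one left */ }
}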
+ for ( ; (i + 0) < m ; ++i ) + { + dcomplex y0c = *y0; + + const dcomplex a0c = *a0; + const dcomplex a1c = *a1; + const dcomplex a2c = *a2; + const dcomplex a3c = *a3; + const dcomplex a4c = *a4; + + y0c.real += chi0.real * a0c.real - chi0.imag * a0c.imag * setPlusOne; + y0c.real += chi1.real * a1c.real - chi1.imag * a1c.imag * setPlusOne; + y0c.real += chi2.real * a2c.real - chi2.imag * a2c.imag * setPlusOne; + y0c.real += chi3.real * a3c.real - chi3.imag * a3c.imag * setPlusOne; + y0c.real += chi4.real * a4c.real - chi4.imag * a4c.imag * setPlusOne; + + y0c.imag += chi0.imag * a0c.real + chi0.real * a0c.imag * setPlusOne; + y0c.imag += chi1.imag * a1c.real + chi1.real * a1c.imag * setPlusOne; + y0c.imag += chi2.imag * a2c.real + chi2.real * a2c.imag * setPlusOne; + y0c.imag += chi3.imag * a3c.real + chi3.real * a3c.imag * setPlusOne; + y0c.imag += chi4.imag * a4c.real + chi4.real * a4c.imag * setPlusOne; + + *y0 = y0c; + + a0 += 1; + a1 += 1; + a2 += 1; + a3 += 1; + a4 += 1; + y0 += 1; + } + } + else + { + for ( ; (i + 0) < m ; ++i ) + { + dcomplex y0c = *y0; + + const dcomplex a0c = *a0; + const dcomplex a1c = *a1; + const dcomplex a2c = *a2; + const dcomplex a3c = *a3; + const dcomplex a4c = *a4; + + y0c.real += chi0.real * a0c.real - chi0.imag * a0c.imag * setPlusOne; + y0c.real += chi1.real * a1c.real - chi1.imag * a1c.imag * setPlusOne; + y0c.real += chi2.real * a2c.real - chi2.imag * a2c.imag * setPlusOne; + y0c.real += chi3.real * a3c.real - chi3.imag * a3c.imag * setPlusOne; + y0c.real += chi4.real * a4c.real - chi4.imag * a4c.imag * setPlusOne; + + y0c.imag += chi0.imag * a0c.real + chi0.real * a0c.imag * setPlusOne; + y0c.imag += chi1.imag * a1c.real + chi1.real * a1c.imag * setPlusOne; + y0c.imag += chi2.imag * a2c.real + chi2.real * a2c.imag * setPlusOne; + y0c.imag += chi3.imag * a3c.real + chi3.real * a3c.imag * setPlusOne; + y0c.imag += chi4.imag * a4c.real + chi4.real * a4c.imag * setPlusOne; + + *y0 = y0c; + + a0 += inca; + a1 += inca; + a2 += inca; + a3 += inca; + a4 += inca; y0 += incy; } diff --git a/test/test_gemv.c b/test/test_gemv.c index 4cc60eefa..b78233def 100644 --- a/test/test_gemv.c +++ b/test/test_gemv.c @@ -124,7 +124,7 @@ int main( int argc, char** argv ) bli_copym( &y, &y_save ); - + dtime_save = DBL_MAX; for ( r = 0; r < n_repeats; ++r ) @@ -149,26 +149,151 @@ int main( int argc, char** argv ) &y ); #else - f77_char transa = 'N'; - f77_int mm = bli_obj_length( &a ); - f77_int nn = bli_obj_width( &a ); - f77_int lda = bli_obj_col_stride( &a ); - f77_int incx = bli_obj_vector_inc( &x ); - f77_int incy = bli_obj_vector_inc( &y ); - double* alphap = bli_obj_buffer( &alpha ); - double* ap = bli_obj_buffer( &a ); - double* xp = bli_obj_buffer( &x ); - double* betap = bli_obj_buffer( &beta ); - double* yp = bli_obj_buffer( &y ); - - dgemv_( &transa, - &mm, - &nn, - alphap, - ap, &lda, - xp, &incx, - betap, - yp, &incy ); +#ifdef CBLAS + enum CBLAS_ORDER cblas_order; + enum CBLAS_TRANSPOSE cblas_transa; + + if ( bli_obj_row_stride( &a ) == 1 ) + cblas_order = CblasColMajor; + else + cblas_order = CblasRowMajor; + + cblas_transa = CblasNoTrans; +#else + f77_char transa = 'N'; +#endif + + if ( bli_is_float( dt ) ){ + f77_int mm = bli_obj_length( &a ); + f77_int nn = bli_obj_width( &a ); + f77_int lda = bli_obj_col_stride( &a ); + f77_int incx = bli_obj_vector_inc( &x ); + f77_int incy = bli_obj_vector_inc( &y ); + float* alphap = bli_obj_buffer( &alpha ); + float* ap = bli_obj_buffer( &a ); + float* xp = bli_obj_buffer( &x ); + float* betap = bli_obj_buffer( 
&beta ); + float* yp = bli_obj_buffer( &y ); +#ifdef CBLAS + cblas_sgemv( cblas_order, + cblas_transa, + mm, + nn, + *alphap, + ap, lda, + xp, incx, + *betap, + yp, incy ); +#else + sgemv_( &transa, + &mm, + &nn, + alphap, + ap, &lda, + xp, &incx, + betap, + yp, &incy ); +#endif + } + else if ( bli_is_double( dt ) ) + { + f77_int mm = bli_obj_length( &a ); + f77_int nn = bli_obj_width( &a ); + f77_int lda = bli_obj_col_stride( &a ); + f77_int incx = bli_obj_vector_inc( &x ); + f77_int incy = bli_obj_vector_inc( &y ); + double* alphap = bli_obj_buffer( &alpha ); + double* ap = bli_obj_buffer( &a ); + double* xp = bli_obj_buffer( &x ); + double* betap = bli_obj_buffer( &beta ); + double* yp = bli_obj_buffer( &y ); +#ifdef CBLAS + cblas_dgemv( cblas_order, + cblas_transa, + mm, + nn, + *alphap, + ap, lda, + xp, incx, + *betap, + yp, incy ); +#else + dgemv_( &transa, + &mm, + &nn, + alphap, + ap, &lda, + xp, &incx, + betap, + yp, &incy ); +#endif + } + else if ( bli_is_scomplex( dt ) ) + { + f77_int mm = bli_obj_length( &a ); + f77_int nn = bli_obj_width( &a ); + f77_int lda = bli_obj_col_stride( &a ); + f77_int incx = bli_obj_vector_inc( &x ); + f77_int incy = bli_obj_vector_inc( &y ); + scomplex* alphap = bli_obj_buffer( &alpha ); + scomplex* ap = bli_obj_buffer( &a ); + scomplex* xp = bli_obj_buffer( &x ); + scomplex* betap = bli_obj_buffer( &beta ); + scomplex* yp = bli_obj_buffer( &y ); +#ifdef CBLAS + cblas_cgemv( cblas_order, + cblas_transa, + mm, + nn, + alphap, + ap, lda, + xp, incx, + betap, + yp, incy ); +#else + cgemv_( &transa, + &mm, + &nn, + alphap, + ap, &lda, + xp, &incx, + betap, + yp, &incy ); +#endif + } + else if ( bli_is_dcomplex( dt ) ) + { + f77_int mm = bli_obj_length( &a ); + f77_int nn = bli_obj_width( &a ); + f77_int lda = bli_obj_col_stride( &a ); + f77_int incx = bli_obj_vector_inc( &x ); + f77_int incy = bli_obj_vector_inc( &y ); + dcomplex* alphap = bli_obj_buffer( &alpha ); + dcomplex* ap = bli_obj_buffer( &a ); + dcomplex* xp = bli_obj_buffer( &x ); + dcomplex* betap = bli_obj_buffer( &beta ); + dcomplex* yp = bli_obj_buffer( &y ); +#ifdef CBLAS + cblas_zgemv( cblas_order, + cblas_transa, + mm, + nn, + alphap, + ap, lda, + xp, incx, + betap, + yp, incy ); +#else + zgemv_( &transa, + &mm, + &nn, + alphap, + ap, &lda, + xp, &incx, + betap, + yp, &incy ); +#endif + } #endif #ifdef PRINT
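The test driver now dispatches on the datatype of the operands and, when the CBLAS macro is defined, routes through the CBLAS interface instead of the Fortran-style BLAS symbols. One detail worth noting: CBLAS passes alpha and beta by value for the real gemv routines but by pointer (const void *) for the complex ones, which is why the float/double branches dereference alphap/betap while the scomplex/dcomplex branches pass the pointers through unchanged. A minimal sketch of the complex-case convention follows (call_zgemv_example is a hypothetical helper; the header name assumes a standard cblas.h is available in the build).

/* Sketch: complex CBLAS gemv takes its scalars by address, not by value. */
#include <cblas.h>
#include <complex.h>

static void call_zgemv_example( int m, int n,
                                const double complex* alpha,
                                const double complex* a, int lda,
                                const double complex* x, int incx,
                                const double complex* beta,
                                double complex* y, int incy )
{
    cblas_zgemv( CblasColMajor, CblasNoTrans, m, n,
                 alpha,            /* const void* : passed by address */
                 a, lda,
                 x, incx,
                 beta,             /* const void* : passed by address */
                 y, incy );
}

The storage-order argument used by the driver is derived from the row stride of a (row stride 1 selects CblasColMajor), matching the selection made at the top of the hunk.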