Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix some aspects of the control tre/plugin infrastructure #827

Open
wants to merge 4 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions build/plugin/bli_plugin.h.in
Original file line number Diff line number Diff line change
Expand Up @@ -137,10 +137,10 @@ INSERT_GENTCONF
#undef GENTCONF
#define GENTCONF( CONFIG, config ) \
\
void PASTEMAC(plugin_init,BLIS_PNAME_INFIX,_,config)( PASTECH(plugin,BLIS_PNAME_INFIX,_params) ); \
void PASTEMAC(plugin_init,BLIS_PNAME_INFIX,_,config,BLIS_REF_SUFFIX)( PASTECH(plugin,BLIS_PNAME_INFIX,_params) );
void PASTEMAC(plugin_init_@plugin_name@_,config)( plugin_@plugin_name@_params ); \
void PASTEMAC(plugin_init_@plugin_name@_,config,BLIS_REF_SUFFIX)( plugin_@plugin_name@_params );

INSERT_GENTCONF

BLIS_EXPORT_BLIS err_t PASTEMAC(plugin_register,BLIS_PNAME_INFIX)( PASTECH(plugin,BLIS_PNAME_INFIX,_params) );
BLIS_EXPORT_BLIS err_t bli_plugin_register_@plugin_name@( plugin_@plugin_name@_params );

4 changes: 2 additions & 2 deletions build/plugin/bli_plugin_init_ref.c
Original file line number Diff line number Diff line change
Expand Up @@ -56,9 +56,9 @@ do { \

// -----------------------------------------------------------------------------

void PASTEMAC(plugin_init,BLIS_PNAME_INFIX,BLIS_CNAME_INFIX,BLIS_REF_SUFFIX)
void PASTEMAC(plugin_init_@plugin_name@,BLIS_CNAME_INFIX,BLIS_REF_SUFFIX)
(
PASTECH(plugin,BLIS_PNAME_INFIX,_params)
plugin_@plugin_name@_params
)
{
cntx_t* cntx = ( cntx_t* )bli_gks_lookup_id( PASTECH(BLIS_ARCH,BLIS_CNAME_UPPER_INFIX) );
Expand Down
9 changes: 7 additions & 2 deletions build/plugin/bli_plugin_init_zen3.c
Original file line number Diff line number Diff line change
Expand Up @@ -34,14 +34,19 @@

#include @PLUGIN_HEADER@

void PASTEMAC(plugin_init,BLIS_PNAME_INFIX,BLIS_CNAME_INFIX)
void PASTEMAC(plugin_init_@plugin_name@,BLIS_CNAME_INFIX)
(
PASTECH(plugin,BLIS_PNAME_INFIX,_params)
plugin_@plugin_name@_params
)
{
cntx_t* cntx = ( cntx_t* )bli_gks_lookup_id( PASTECH(BLIS_ARCH,BLIS_CNAME_UPPER_INFIX) );
( void )cntx;

PASTEMAC(plugin_init_@plugin_name@,BLIS_CNAME_INFIX,BLIS_REF_SUFFIX)
(
plugin_@plugin_name@_params_only
);

// ------------------------------------------------------------------------>
// -- Example Initialization ---------------------------------------------->
// ------------------------------------------------------------------------>
Expand Down
8 changes: 4 additions & 4 deletions build/plugin/bli_plugin_register.c
Original file line number Diff line number Diff line change
Expand Up @@ -34,9 +34,9 @@

#include @PLUGIN_HEADER@

err_t PASTEMAC(plugin_register,BLIS_PNAME_INFIX)
err_t bli_plugin_register_@plugin_name@
(
PASTECH(plugin,BLIS_PNAME_INFIX,_params)
plugin_@plugin_name@_params
)
{
// ------------------------------------------------------------------------>
Expand Down Expand Up @@ -69,9 +69,9 @@ err_t PASTEMAC(plugin_register,BLIS_PNAME_INFIX)

#undef GENTCONF
#define GENTCONF( CONFIG, config ) \
PASTEMAC(plugin_init,BLIS_PNAME_INFIX,_,config,BLIS_REF_SUFFIX) \
PASTEMAC(plugin_init_@plugin_name@_,config) \
( \
PASTECH(plugin,BLIS_PNAME_INFIX,_params_only) \
plugin_@plugin_name@_params_only \
);

INSERT_GENTCONF
Expand Down
8 changes: 4 additions & 4 deletions configure
Original file line number Diff line number Diff line change
Expand Up @@ -5243,7 +5243,7 @@ plugin_main()
else
strip_examples ${sharedir}/blis/plugin/bli_plugin_register.c bli_plugin_register.c
fi
perl -pi -e "s|\@PLUGIN_HEADER\@|${plugin_h}|" bli_plugin_register.c
perl -pi -e "s|\@PLUGIN_HEADER\@|${plugin_h}|;" -e "s|\@plugin_name\@|${plugin_name}|;" bli_plugin_register.c
maybe_echo "done"
fi

Expand Down Expand Up @@ -5279,7 +5279,7 @@ plugin_main()
else
strip_examples ${sharedir}/blis/plugin/${file} ref_kernels/${file}
fi
perl -pi -e "s|\@PLUGIN_HEADER\@|${plugin_h}|" ref_kernels/${file}
perl -pi -e "s|\@PLUGIN_HEADER\@|${plugin_h}|;" -e "s|\@plugin_name\@|${plugin_name}|;" ref_kernels/${file}
done="false"
fi
done
Expand Down Expand Up @@ -5314,7 +5314,7 @@ plugin_main()
else
cp ${sharedir}/blis/plugin/bli_plugin_init_zen3.c config/${config}/bli_plugin_init_${config}.c
fi
perl -pi -e "s|\@PLUGIN_HEADER\@|${plugin_h}|" config/${config}/bli_plugin_init_${config}.c
perl -pi -e "s|\@PLUGIN_HEADER\@|${plugin_h}|;" -e "s|\@plugin_name\@|${plugin_name}|;" config/${config}/bli_plugin_init_${config}.c
fi

if [ ! -e config/${config}/bli_kernel_defs_${config}.h ] || [ ${force_flag} == '1' ]; then
Expand All @@ -5340,7 +5340,7 @@ plugin_main()
if [ ${examples_flag} == '1' ]; then
if [ ! -e kernels/zen3/my_kernel_1_zen3.c ] || [ ${force_flag} == '1' ]; then
cp ${sharedir}/blis/plugin/my_kernel_1_zen3.c kernels/zen3
perl -pi -e "s|\@PLUGIN_HEADER\@|${plugin_h}|" kernels/zen3/my_kernel_1_zen3.c
perl -pi -e "s|\@PLUGIN_HEADER\@|${plugin_h}|;" -e "s|\@plugin_name\@|${plugin_name}|;" kernels/zen3/my_kernel_1_zen3.c
fi
fi

Expand Down
4 changes: 3 additions & 1 deletion frame/3/gemm/bli_gemm_cntl.c
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ void bli_gemm_var_cntl_init_node
);
}

void bli_gemm_cntl_init
bool bli_gemm_cntl_init
(
ind_t im,
opid_t family,
Expand Down Expand Up @@ -559,6 +559,8 @@ void bli_gemm_cntl_init
c,
cntl
);

return needs_swap;
}

void bli_gemm_cntl_finalize
Expand Down
2 changes: 1 addition & 1 deletion frame/3/gemm/bli_gemm_cntl.h
Original file line number Diff line number Diff line change
Expand Up @@ -200,7 +200,7 @@ typedef struct gemm_cntl_s gemm_cntl_t;

// -----------------------------------------------------------------------------

BLIS_EXPORT_BLIS void bli_gemm_cntl_init
BLIS_EXPORT_BLIS bool bli_gemm_cntl_init
(
ind_t im,
opid_t family,
Expand Down
71 changes: 33 additions & 38 deletions frame/base/bli_cntx.c
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ BLIS_EXPORT_BLIS err_t bli_cntx_init( cntx_t* cntx )
if ( error != BLIS_SUCCESS )
return error;

error = bli_stack_init( sizeof( bszid_t ), 32, 32, BLIS_NUM_BLKSZS, &cntx->bmults );
error = bli_stack_init( sizeof( siz_t ), 32, 32, BLIS_NUM_BLKSZS, &cntx->bmults );
if ( error != BLIS_SUCCESS )
return error;

Expand Down Expand Up @@ -118,9 +118,9 @@ void bli_cntx_set_blkszs( cntx_t* cntx, ... )
void bli_cntx_set_blkszs
(
cntx_t* cntx,
bszid_t bs0_id, blksz_t* blksz0, bszid_t bm0_id,
bszid_t bs1_id, blksz_t* blksz1, bszid_t bm1_id,
bszid_t bs2_id, blksz_t* blksz2, bszid_t bm2_id,
siz_t bs0_id, blksz_t* blksz0, siz_t bm0_id,
siz_t bs1_id, blksz_t* blksz1, siz_t bm1_id,
siz_t bs2_id, blksz_t* blksz2, siz_t bm2_id,
...,
BLIS_VA_END
);
Expand All @@ -133,19 +133,18 @@ void bli_cntx_set_blkszs( cntx_t* cntx, ... )
// Process blocksizes until we get a BLIS_VA_END.
while ( true )
{
int bs_id0 = va_arg( args, int );
int bs_id = va_arg( args, siz_t );

// If we find a bszid_t id of BLIS_VA_END, then we are done.
if ( bs_id0 == BLIS_VA_END ) break;
// If we find a siz_t id of BLIS_VA_END, then we are done.
if ( bs_id == BLIS_VA_END ) break;

// Here, we query the variable argument list for:
// - the bszid_t of the blocksize we're about to process (already done),
// - the siz_t of the blocksize we're about to process (already done),
// - the address of the blksz_t object,
// - the bszid_t of the multiple we need to associate with
// - the siz_t of the multiple we need to associate with
// the blksz_t object.
bszid_t bs_id = ( bszid_t )bs_id0;
blksz_t* blksz = ( blksz_t* )va_arg( args, blksz_t* );
bszid_t bm_id = ( bszid_t )va_arg( args, bszid_t );
siz_t bm_id = ( siz_t )va_arg( args, siz_t );

// Copy the blksz_t object contents into the appropriate
// location within the context's blksz_t array. Do the same
Expand All @@ -172,9 +171,9 @@ void bli_cntx_set_ukrs( cntx_t* cntx , ... )
void bli_cntx_set_ukrs
(
cntx_t* cntx,
ukr_t ukr0_id, num_t dt0, void_fp ukr0_fp,
ukr_t ukr1_id, num_t dt1, void_fp ukr1_fp,
ukr_t ukr2_id, num_t dt2, void_fp ukr2_fp,
siz_t ukr0_id, num_t dt0, void_fp ukr0_fp,
siz_t ukr1_id, num_t dt1, void_fp ukr1_fp,
siz_t ukr2_id, num_t dt2, void_fp ukr2_fp,
...,
BLIS_VA_END
);
Expand All @@ -187,16 +186,15 @@ void bli_cntx_set_ukrs( cntx_t* cntx , ... )
// Process ukernels until BLIS_VA_END is reached.
while ( true )
{
const int ukr_id0 = va_arg( args, int );
const int ukr_id = va_arg( args, siz_t );

// If we find a ukernel id of BLIS_VA_END, then we are done.
if ( ukr_id0 == BLIS_VA_END ) break;
if ( ukr_id == BLIS_VA_END ) break;

// Here, we query the variable argument list for:
// - the ukr_t of the kernel we're about to process (already done),
// - the siz_t of the kernel we're about to process (already done),
// - the datatype of the kernel, and
// - the kernel function pointer
const ukr_t ukr_id = ( ukr_t )ukr_id0;
const num_t ukr_dt = ( num_t )va_arg( args, num_t );
void_fp ukr_fp = ( void_fp )va_arg( args, void_fp );

Expand All @@ -223,9 +221,9 @@ void bli_cntx_set_ukr2s( cntx_t* cntx , ... )
void bli_cntx_set_ukr2s
(
cntx_t* cntx,
ukr_t ukr0_id, num_t dt1_0, num_t dt2_0, void_fp ukr0_fp,
ukr_t ukr1_id, num_t dt1_1, num_t dt2_1, void_fp ukr1_fp,
ukr_t ukr2_id, num_t dt1_2, num_t dt2_2, void_fp ukr2_fp,
siz_t ukr0_id, num_t dt1_0, num_t dt2_0, void_fp ukr0_fp,
siz_t ukr1_id, num_t dt1_1, num_t dt2_1, void_fp ukr1_fp,
siz_t ukr2_id, num_t dt1_2, num_t dt2_2, void_fp ukr2_fp,
...,
BLIS_VA_END
);
Expand All @@ -238,16 +236,15 @@ void bli_cntx_set_ukr2s( cntx_t* cntx , ... )
// Process ukernels until BLIS_VA_END is reached.
while ( true )
{
const int ukr_id0 = va_arg( args, int );
const int ukr_id = va_arg( args, siz_t );

// If we find a ukernel id of BLIS_VA_END, then we are done.
if ( ukr_id0 == BLIS_VA_END ) break;
if ( ukr_id == BLIS_VA_END ) break;

// Here, we query the variable argument list for:
// - the ukr_t of the kernel we're about to process (already done),
// - the siz_t of the kernel we're about to process (already done),
// - the datatype of the kernel, and
// - the kernel function pointer
const ukr_t ukr_id = ( ukr_t )ukr_id0;
const num_t ukr_dt1 = ( num_t )va_arg( args, num_t );
const num_t ukr_dt2 = ( num_t )va_arg( args, num_t );
void_fp ukr_fp = ( void_fp )va_arg( args, void_fp );
Expand Down Expand Up @@ -275,9 +272,9 @@ void bli_cntx_set_ukr_prefs( cntx_t* cntx , ... )
void bli_cntx_set_ukr_prefs
(
cntx_t* cntx,
ukr_pref_t ukr_pref0_id, num_t dt0, bool ukr_pref0,
ukr_pref_t ukr_pref1_id, num_t dt1, bool ukr_pref1,
ukr_pref_t ukr_pref2_id, num_t dt2, bool ukr_pref2,
siz_t ukr_pref0_id, num_t dt0, bool ukr_pref0,
siz_t ukr_pref1_id, num_t dt1, bool ukr_pref1,
siz_t ukr_pref2_id, num_t dt2, bool ukr_pref2,
...,
BLIS_VA_END
);
Expand All @@ -290,18 +287,17 @@ void bli_cntx_set_ukr_prefs( cntx_t* cntx , ... )
// Process ukernel preferences until BLIS_VA_END is reached.
while ( true )
{
const int ukr_pref_id0 = va_arg( args, int );
const int ukr_pref_id = va_arg( args, siz_t );

// If we find a ukernel pref id of BLIS_VA_END, then we are done.
if ( ukr_pref_id0 == BLIS_VA_END ) break;
if ( ukr_pref_id == BLIS_VA_END ) break;

// Here, we query the variable argument list for:
// - the ukr_t of the kernel we're about to process (already done),
// - the siz_t of the kernel we're about to process (already done),
// - the datatype of the kernel, and
// - the kernel function pointer
const ukr_pref_t ukr_pref_id = ( ukr_pref_t )ukr_pref_id0;
const num_t ukr_pref_dt = ( num_t )va_arg( args, num_t );
const bool ukr_pref = ( bool )va_arg( args, int );
const num_t ukr_pref_dt = ( num_t )va_arg( args, num_t );
const bool ukr_pref = ( bool )va_arg( args, int );

// Store the ukernel preference value into the context.
bli_cntx_set_ukr_pref_dt( ukr_pref, ukr_pref_dt, ukr_pref_id, cntx );
Expand Down Expand Up @@ -341,15 +337,14 @@ void bli_cntx_set_l3_sup_handlers( cntx_t* cntx, ... )
// Process sup handlers until BLIS_VA_END is reached.
while ( true )
{
const int op_id0 = va_arg( args, int );
const opid_t op_id = va_arg( args, siz_t );

// If we find an operation id of BLIS_VA_END, then we are done.
if ( op_id0 == BLIS_VA_END ) break;
if ( op_id == BLIS_VA_END ) break;

// Here, we query the variable argument list for:
// - the opid_t of the operation we're about to process,
// - the sup handler function pointer
const opid_t op_id = ( opid_t )op_id0;
void_fp op_fp = ( void_fp )va_arg( args, void_fp );

if ( op_id >= BLIS_NUM_LEVEL3_OPS )
Expand All @@ -368,7 +363,7 @@ void bli_cntx_set_l3_sup_handlers( cntx_t* cntx, ... )

// -----------------------------------------------------------------------------

err_t bli_cntx_register_blksz( siz_t* bs_id, const blksz_t* blksz, bszid_t bmult_id, cntx_t* cntx )
err_t bli_cntx_register_blksz( siz_t* bs_id, const blksz_t* blksz, siz_t bmult_id, cntx_t* cntx )
{
siz_t id_blksz;
err_t error = bli_stack_push( &id_blksz, &cntx->blkszs );
Expand Down
Loading