
Slow compilation for transposed convolution #5541

Open

hawkinsp opened this issue Sep 11, 2023 · 0 comments
Labels: GPU (XLA on GPU), NVIDIA-GPU (XLA on Nvidia GPU)

hawkinsp commented Sep 11, 2023

The following HLO module consists of a single convolution and takes a very long time to compile (over 30 seconds on my workstation):

HloModule jit_toy_model_jax, entry_computation_layout={(f32[1,8,8,128]{3,2,1,0}, f32[32,32,128,32]{3,2,1,0})->f32[1,144,144,32]{3,2,1,0}}

ENTRY main.4 {
  Arg_0.1 = f32[1,8,8,128]{3,2,1,0} parameter(0), sharding={replicated}
  Arg_1.2 = f32[32,32,128,32]{3,2,1,0} parameter(1), sharding={replicated}
  ROOT convolution.3 = f32[1,144,144,32]{3,2,1,0} convolution(Arg_0.1, Arg_1.2), window={size=32x32 pad=31_31x31_31 lhs_dilate=16x16}, dim_labels=b01f_01io->b01f
}

This bug was originally reported as jax-ml/jax#17464, where the user reported that their model took over 4 minutes to compile.

Something seems wrong here...
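
For reference, here is a minimal JAX sketch that should lower to roughly this convolution; the shapes and window parameters are inferred from the HLO above, and this is not the original user's code:

import jax
import jax.numpy as jnp

# Transposed convolution expressed as conv_general_dilated with
# lhs_dilation=(16, 16) and symmetric padding of 31 on each spatial
# dimension, matching the window attributes of the HLO above.
# Shapes are taken from the entry computation layout.
lhs = jnp.zeros((1, 8, 8, 128), jnp.float32)     # b01f input
rhs = jnp.zeros((32, 32, 128, 32), jnp.float32)  # 01io kernel

def toy_model(lhs, rhs):
  return jax.lax.conv_general_dilated(
      lhs, rhs,
      window_strides=(1, 1),
      padding=((31, 31), (31, 31)),
      lhs_dilation=(16, 16),
      dimension_numbers=("NHWC", "HWIO", "NHWC"),
  )

# Compilation (not execution) is the slow step being reported here.
compiled = jax.jit(toy_model).lower(lhs, rhs).compile()

The resulting output has shape (1, 144, 144, 32), matching the ROOT convolution in the HLO module.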
