
lax.conv_transpose takes FOREVER to compile #17464

Open
sokrypton opened this issue Sep 6, 2023 · 3 comments
Labels
bug · NVIDIA GPU · XLA

Comments

@sokrypton

sokrypton commented Sep 6, 2023

Description

I initially submitted the issue at google-deepmind/dm-haiku#724, but then realized it is a JAX issue.

In short, I've been trying to use Conv2DTranspose in my model, and even for a very simple case it takes forever to compile.

import jax
from jax import lax, random
import jax.numpy as jnp
import time

# Directly implement the Conv2DTranspose in JAX
def toy_model_jax(x, params):
    return lax.conv_transpose(x, params["kernel"], strides=(16, 16), padding="VALID")

# Initialize parameters for the toy model
def initialize_params(key):
    kernel_shape = (32, 32, 128, 32)  # (height, width, in_channels, out_channels)
    kernel = random.normal(key, kernel_shape)
    return {"kernel": kernel}

# Generate random input and params
start_time = time.time()
key = random.PRNGKey(42)
x = random.normal(key, (1, 8, 8, 128))
params = initialize_params(key)
end_time = time.time()
print(f"Initialization Run Time: {end_time - start_time:.6f} seconds")

# JIT-compile and time the model run
toy_model_jax_jitted = jax.jit(toy_model_jax)

# Time the model compilation
start_time = time.time()
# Warm-up call (this triggers tracing and XLA compilation)
_ = toy_model_jax_jitted(x, params).block_until_ready()
end_time = time.time()
print(f"JAX Compilation Time: {end_time - start_time:.6f} seconds")

# Time the model run (block_until_ready waits for the async dispatch to finish)
start_time = time.time()
o = toy_model_jax_jitted(x, params).block_until_ready()
end_time = time.time()
print("input_shape", x.shape)
print("output_shape", o.shape)
print(f"JITted Run Time: {end_time - start_time:.6f} seconds")

Output:

Initialization Run Time: 2.540971 seconds
JAX Compilation Time: 251.976538 seconds
input_shape (1, 8, 8, 128)
output_shape (1, 144, 144, 32)
JITted Run Time: 0.001842 seconds

For comparison, here is the PyTorch equivalent:

Initialization Time: 0.033582 seconds
input_shape torch.Size([1, 128, 8, 8])
output_shape torch.Size([1, 32, 144, 144])
Run Time: 0.047478 seconds

Google colab notebook:
https://colab.research.google.com/drive/15YkOuK0EjqZdBNaXpF2wpYexGqtjZjLr
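
As an aside, the compile step can also be timed in isolation with jax.jit's ahead-of-time APIs, instead of timing a warm-up call. A minimal sketch, assuming a recent jax version where .lower() and .compile() are available:

# Sketch: separate tracing/lowering from XLA compilation so the slow part
# can be timed on its own.
lowered = jax.jit(toy_model_jax).lower(x, params)  # trace + lower to HLO (fast)
start_time = time.time()
compiled = lowered.compile()                       # XLA compilation (the slow step)
print(f"XLA compile time: {time.time() - start_time:.6f} seconds")
o = compiled(x, params)                            # run the compiled executable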

What jax/jaxlib version are you using?

Google Colab

Which accelerator(s) are you using?

GPU

Additional system info

Google Colab

NVIDIA GPU info

Wed Sep  6 14:21:28 2023       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 525.105.17   Driver Version: 525.105.17   CUDA Version: 12.0     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|===============================+======================+======================|
|   0  Tesla T4            Off  | 00000000:00:04.0 Off |                    0 |
| N/A   58C    P0    28W /  70W |  11957MiB / 15360MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Processes:                                                                  |
|  GPU   GI   CI        PID   Type   Process name                  GPU Memory |
|        ID   ID                                                   Usage      |
|=============================================================================|
+-----------------------------------------------------------------------------+
@hawkinsp
Collaborator

I think this in turn is an XLA bug. Opened openxla/xla#5541.

@hawkinsp hawkinsp added the XLA label Sep 11, 2023
@hawkinsp
Collaborator

This is apparently due to convolution autotuning: some of the algorithms in cuDNN are very slow, and we try them all during autotuning. Once autotuning has run, we choose a fast algorithm.
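
In the meantime, a possible workaround is to dial down XLA's GPU autotuning via the XLA_FLAGS environment variable (level 0 disables autotuning entirely, at the cost of possibly picking a slower convolution algorithm at runtime). A minimal sketch, assuming the flag is set before jax initializes its backends:

import os
# Must run before the first jax import/use (e.g. at the top of the notebook).
# Lower levels try fewer (or no) cudnn algorithms; the default level is 4.
os.environ["XLA_FLAGS"] = os.environ.get("XLA_FLAGS", "") + " --xla_gpu_autotune_level=0"

import jax  # jax must be imported after XLA_FLAGS is set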

@akuegel
Contributor

akuegel commented Nov 20, 2023

It seems that in this case the same algorithms are returned by heuristics_mode_a and heuristics_mode_b, so by deduplicating the algorithms tried during autotuning we can halve the compile time. That still leaves it slow, but it is a step in the right direction. There is also an idea for speeding it up further by aborting an autotuning attempt once it exceeds the best known runtime, but that will take a bit longer to implement.
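
Until those improvements land, one way to avoid paying the compile cost repeatedly is jax's persistent compilation cache. A sketch, assuming a jax version that ships jax.experimental.compilation_cache (the API has moved between versions):

from jax.experimental.compilation_cache import compilation_cache as cc

# Persist compiled XLA executables on disk, so an identical computation
# compiles once per machine/GPU rather than once per process.
cc.initialize_cache("/tmp/jax_compilation_cache")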

@hawkinsp hawkinsp added the NVIDIA GPU Issues specific to NVIDIA GPUs label Nov 21, 2023