
hk.Conv2DTranspose takes FOREVER to initialize and compile #724

Open
sokrypton opened this issue Sep 6, 2023 · 1 comment

@sokrypton

Not sure if this is a JAX thing or a dm-haiku thing, but I've recently been trying to use Conv2DTranspose in my model, and even for a very simple case it takes forever to compile.

Here is an example:

import haiku as hk
import jax
from jax import random
import time

def toy_model(x):
  x = hk.Conv2DTranspose(32, 32, stride=16, padding="VALID")(x)
  return x

# Transform the model to be JAX-compatible (transform once, reuse init and apply)
transformed = hk.transform(toy_model)
toy_model_init = transformed.init
toy_model_apply = transformed.apply

# Generate random input and params
key = random.PRNGKey(42)
x = random.normal(key, (1, 8, 8, 128))

# Time the model initialization
start_time = time.time()
params = toy_model_init(key, x)
end_time = time.time()
print(f"initialization Time: {end_time - start_time:.6f} seconds")

# Time the model compilation
start_time = time.time()
compiled_apply = jax.jit(toy_model_apply)
# Warm-up call (this compiles the function)
_ = compiled_apply(params, None, x)
end_time = time.time()
print(f"Compilation Time: {end_time - start_time:.6f} seconds")

# Time the model run (block_until_ready waits for JAX's async dispatch,
# so the timing reflects actual execution rather than just dispatch)
start_time = time.time()
o = compiled_apply(params, None, x).block_until_ready()
end_time = time.time()
print("input_shape", x.shape)
print("output_shape", o.shape)
print(f"Run Time: {end_time - start_time:.6f} seconds")

Output:

initialization Time: 251.865844 seconds
Compilation Time: 255.010969 seconds
input_shape (1, 8, 8, 128)
output_shape (1, 144, 144, 32)
Run Time: 0.000671 seconds
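As a side note (not part of the original report), the reported output shape can be checked without triggering any XLA compilation by tracing the op abstractly with jax.eval_shape. This sketch calls lax.conv_transpose directly with the same configuration, assuming the layout Haiku uses by default (NHWC inputs, HWIO kernels):

```python
import jax
import jax.numpy as jnp
from jax import lax

def tconv(x, w):
    # Same configuration as the Haiku layer above: 32 output channels,
    # a 32x32 kernel, stride 16, VALID padding.
    return lax.conv_transpose(x, w, strides=(16, 16), padding="VALID",
                              dimension_numbers=("NHWC", "HWIO", "NHWC"))

# Abstract values: shapes and dtypes only, no data and no compilation.
x = jax.ShapeDtypeStruct((1, 8, 8, 128), jnp.float32)
w = jax.ShapeDtypeStruct((32, 32, 128, 32), jnp.float32)  # HWIO layout

out = jax.eval_shape(tconv, x, w)
print(out.shape)  # (1, 144, 144, 32), matching the output above
```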

For comparison, in PyTorch:

Initialization Time: 0.033582 seconds
input_shape torch.Size([1, 128, 8, 8])
output_shape torch.Size([1, 32, 144, 144])
Run Time: 0.047478 seconds
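The PyTorch numbers above were presumably produced by something like the following sketch (a reconstruction, not the original script; the layer arguments are inferred from the shapes in the reported output):

```python
import time
import torch

# 128 input channels, 32 output channels, 32x32 kernel, stride 16 --
# inferred from the input/output shapes reported above.
start = time.time()
layer = torch.nn.ConvTranspose2d(128, 32, kernel_size=32, stride=16)
print(f"Initialization Time: {time.time() - start:.6f} seconds")

x = torch.randn(1, 128, 8, 8)  # PyTorch uses NCHW layout
start = time.time()
o = layer(x)
print("input_shape", x.shape)
print("output_shape", o.shape)  # torch.Size([1, 32, 144, 144])
print(f"Run Time: {time.time() - start:.6f} seconds")
```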

Google colab notebook replicating the test:
https://colab.research.google.com/drive/15YkOuK0EjqZdBNaXpF2wpYexGqtjZjLr

@sokrypton (Author)

OK, this appears to be a lax.conv_transpose issue. I tried running lax.conv_transpose directly and see the same behavior.
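One way to narrow down where the time goes (a sketch using JAX's ahead-of-time lowering API, not from the original thread) is to separate tracing/lowering from XLA compilation. Lowering is quick, which suggests the cost is in the subsequent compile() step:

```python
import jax
import jax.numpy as jnp
from jax import lax

def tconv(x, w):
    return lax.conv_transpose(x, w, strides=(16, 16), padding="VALID",
                              dimension_numbers=("NHWC", "HWIO", "NHWC"))

x = jnp.zeros((1, 8, 8, 128), jnp.float32)
w = jnp.zeros((32, 32, 128, 32), jnp.float32)

# Tracing and lowering to StableHLO is fast; no XLA compilation happens yet.
lowered = jax.jit(tconv).lower(x, w)
print("convolution" in lowered.as_text())  # True: it lowers to an XLA convolution

# lowered.compile() is the step that would trigger the slow XLA compile
# reported above, so it is deliberately not called here.
```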
