
The inference speed on the mobile end is a bit slow #928

Open

Gratifyyy opened this issue Sep 10, 2024 · 7 comments

Labels: question (Further information is requested)

Comments


Gratifyyy commented Sep 10, 2024

Question

If the mobile device does not support WebGPU, how can we improve the inference speed of the model? I have tried a Web Worker, but the results were not satisfactory.

Gratifyyy added the question label on Sep 10, 2024
Gratifyyy changed the title from 移动端推理速度有点慢 to "The inference speed on the mobile end is a bit slow" on Sep 10, 2024

gyagp commented Sep 10, 2024

Did you enable WASM SIMD and multi-threading? If not, you can try setting env.backends.onnx.wasm.simd = true and env.backends.onnx.wasm.numThreads = xxx (a reasonable value for your core count).
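
For reference, a minimal sketch of that configuration (the thread count here is just an illustrative choice derived from navigator.hardwareConcurrency, not a recommended value):

import { env } from '@xenova/transformers';

// Enable WASM SIMD explicitly and pick a thread count for the device.
// navigator.hardwareConcurrency is only a rough heuristic; tune as needed.
env.backends.onnx.wasm.simd = true;
env.backends.onnx.wasm.numThreads = Math.min(4, navigator.hardwareConcurrency ?? 1);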

@flatsiedatsie

I believe multi-threading is already handled automatically: #882

I'd imagine SIMD is handled the same way?

@Gratifyyy (Author)

Both are turned on, but there is no noticeable change:

import { AutoModel, AutoProcessor, env, RawImage } from '@xenova/transformers';

env.allowLocalModels = false;

env.backends.onnx.wasm.proxy = true;
env.backends.onnx.wasm.simd = true;
env.backends.onnx.wasm.numThreads = 4;

const model = await AutoModel.from_pretrained('briaai/RMBG-1.4', {
  config: { model_type: 'custom' },
});

const processor = await AutoProcessor.from_pretrained('briaai/RMBG-1.4', {
  config: {
      do_normalize: true,
      do_pad: false,
      do_rescale: true,
      do_resize: true,
      image_mean: [0.5, 0.5, 0.5],
      feature_extractor_type: "ImageFeatureExtractor",
      image_std: [1, 1, 1],
      resample: 2,
      rescale_factor: 0.00392156862745098,
      size: { width: 1024, height: 1024 },
  }
});

export const predict = async (url: string) =>  {
  // Read image
  const image = await RawImage.fromURL(url);

  // Preprocess image
  const { pixel_values } = await processor(image);

  // Predict alpha matte
  const { output } = await model({ input: pixel_values });

  const pixelData = image.rgba();
  // Resize mask back to original size
  const mask = await RawImage.fromTensor(output[0].mul(255).to('uint8')).resize(image.width, image.height);
  // Convert alpha channel to 4th channel
  for (let i = 0; i < mask.data.length; ++i) {
    pixelData.data[4 * i + 3] = mask.data[i];
  }
  return (pixelData.toSharp());
}

@flatsiedatsie

@xenova Does SIMD have to be enabled manually?


gyagp commented Sep 11, 2024

Usually we'd see a perf change when playing with env.backends.onnx.wasm.numThreads (better or worse; a larger number doesn't always mean better). But to make multi-threading work, your web server needs to be cross-origin isolated (https://web.dev/articles/coop-coep). You can open your console and check whether crossOriginIsolated is true or false.
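
For context, cross-origin isolation means the page is served with COOP and COEP headers. A minimal sketch of those headers, assuming an Express static server (the server setup is an assumption; the exact mechanism depends on your hosting):

import express from 'express';

const app = express();

// Both headers are required for cross-origin isolation, which enables
// SharedArrayBuffer and therefore multi-threaded WASM in ONNX Runtime Web.
app.use((_req, res, next) => {
  res.setHeader('Cross-Origin-Opener-Policy', 'same-origin');
  res.setHeader('Cross-Origin-Embedder-Policy', 'require-corp');
  next();
});

app.use(express.static('dist'));
app.listen(8080);

With those headers in place, crossOriginIsolated in the browser console should report true.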


gyagp commented Sep 11, 2024

@xenova Does SIMD have to be enabled manually?

Transformers.js doesn't have to do anything special to set either simd or numThreads. By default, ORT has SIMD enabled, and numThreads = min(4, ceil(cpu_core_num / 2)) if crossOriginIsolated is true (thanks @fs-eire for confirming).
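
For reference, that default roughly corresponds to the following in browser JavaScript (a sketch of the formula, not the actual ORT source):

// Default WASM thread count: at most 4, and no more than half the logical cores,
// and only when the page is cross-origin isolated (otherwise single-threaded).
const cores = navigator.hardwareConcurrency ?? 1;
const defaultNumThreads = crossOriginIsolated ? Math.min(4, Math.ceil(cores / 2)) : 1;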

@Gratifyyy (Author)

Usually we'd see a perf change when playing with env.backends.onnx.wasm.numThreads (better or worse; a larger number doesn't always mean better). But to make multi-threading work, your web server needs to be cross-origin isolated (https://web.dev/articles/coop-coep). You can open your console and check whether crossOriginIsolated is true or false.

Thanks for the reminder. I just checked and crossOriginIsolated is false. I will try to change it to see if it has any effect.
