[Draft] Avoid loading model weights before recipe application if any #2230

Draft · wants to merge 2 commits into base: main

Conversation

@rahul-tuli (Member) commented Apr 8, 2024

Previously, when SparseAutoModelForCausalLM.from_pretrained(...) was called, the weights were loaded twice: once during model = super(AutoModelForCausalLM, cls).from_pretrained(...) and then again after recipe application, which is undesirable.

This PR updates the flow to use from_config(...) instead of from_pretrained(...): from_config initializes the model with fresh (untrained) weight data, and the actual trained weights are loaded in only once, after recipe application.

More info on from_config: https://huggingface.co/transformers/v3.0.2/model_doc/auto.html#transformers.AutoModel.from_config
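
For illustration, here is a minimal sketch of the from_config flow described above; the model path is borrowed from the test script below, and the trailing comments stand in for the recipe and reload steps, which this snippet does not implement:

from transformers import AutoConfig, AutoModelForCausalLM

model_path = "Xenova/llama2.c-stories15M"         # example checkpoint, see test script below
config = AutoConfig.from_pretrained(model_path)   # reads only the config, no weight tensors
model = AutoModelForCausalLM.from_config(config)  # builds the model with freshly initialized weights
# ... apply the recipe / structural changes here ...
# ... then load the trained weights back in, exactly once ...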

The initial effort was to accomplish this with accelerate.init_empty_weights, but we ran into the following issue with quantized models: https://discuss.huggingface.co/t/error-the-model-weights-are-not-tied-please-use-the-tie-weights-method-before-using-the-infer-auto-device-function-even-after-adding-model-tie-weights/46325
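
For context, the abandoned approach looked roughly like this (a sketch, not the PR's code): under accelerate's init_empty_weights the parameters are created on the meta device so no weight memory is allocated, and the tie-weights error linked above showed up on this path for quantized checkpoints.

from accelerate import init_empty_weights
from transformers import AutoConfig, AutoModelForCausalLM

config = AutoConfig.from_pretrained("Xenova/llama2.c-stories15M")  # example checkpoint
with init_empty_weights():
    # parameters are placed on the "meta" device, so no real memory is allocated
    model = AutoModelForCausalLM.from_config(config)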

Tests: Loaded dense, sparse, and quantized checkpoints; all of them load correctly.

Test script:

import time
from typing import List
from sparseml.transformers import SparseAutoModelForCausalLM
from argparse import ArgumentParser

parser = ArgumentParser()
parser.add_argument("--model-type", type=str, choices=["dense", "sparse", "quantized"], default="quantized")
parser.add_argument("--all", action="store_true")

BASE_MODEL = "Xenova/llama2.c-stories15M"

# Define the model paths for each model type
models = {
    "dense": "Xenova/llama2.c-stories15M",
    "sparse": "/home/rahul/projects/sparseml/local/local_output/sparse_model_80",
    "quantized": "mgoin/llama2.c-stories15M-quant-pt",
}

def load_and_time(model_path):
    start_time = time.time()
    SparseAutoModelForCausalLM.from_pretrained(model_path)
    end_time = time.time()
    return end_time - start_time

def load_weights(model_types: List[str]):
    return {
        model_type: load_and_time(models[model_type])
        for model_type in model_types
    }

def main(args):
    timings = (
        load_weights(model_types=list(models.keys()))
        if args.all
        else load_weights(model_types=[args.model_type])
    )
    print(timings)

if __name__ == "__main__":
    args = parser.parse_args()
    main(args=args)
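
To reproduce, the script can be saved locally (the filename below is arbitrary) and run either for a single checkpoint type or for all three:

python load_timing.py --model-type quantized
python load_timing.py --all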
    

"""
Takes a loaded Pytorch model and applies any structural changes such as quantization
to the model, then reloads the model.

:param model: PyTorch model to apply structure to
:param recipe_path: path to recipe to apply to the model
:param model_path: path to model, used for reloading the state dict
:param reload_weights: flag to reload the weights after applying the recipe.
Dafault is True.
A contributor suggested a change to the docstring:

- Dafault is True.
+ Default is True.
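
For orientation, a rough sketch of what a helper with this docstring might do; the function name, checkpoint filename, and the elided recipe step are assumptions for illustration, not the PR's actual code:

import os
import torch

def apply_structure_then_reload(model, recipe_path, model_path, reload_weights=True):
    # ... apply the structural changes defined in the recipe at recipe_path
    #     (e.g. quantization wrappers) to the model; elided here ...
    if reload_weights:
        # reload the trained weights so they land in the now-matching structure
        state_dict = torch.load(os.path.join(model_path, "pytorch_model.bin"), map_location="cpu")
        model.load_state_dict(state_dict, strict=False)
    return model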

@dbogunowicz (Contributor) left a comment:
Looking good!

@@ -130,12 +135,27 @@ def skip(*args, **kwargs):
compressor.overwrite_weights(model_path=model_path, model=model)

recipe = resolve_recipe(recipe=recipe, model_path=pretrained_model_name_or_path)

# this must be done before recipe is applied
A contributor commented:
curious, why? how does this modify the state of the model?

@dbogunowicz (Contributor):

Also @rahul-tuli, the correct implementation of this PR should make this part of the from_pretrained method:

def skip(*args, **kwargs):
    pass
# Skip the initializer step. This accelerates the loading
# of the models, especially for the quantized models
torch.nn.init.kaiming_uniform_ = skip
torch.nn.init.uniform_ = skip
torch.nn.init.normal_ = skip

redundant!
