From 29f89cbe11eb933b76b86ad1cef00d4069d906f0 Mon Sep 17 00:00:00 2001
From: Abhinav Agarwalla
Date: Thu, 4 Apr 2024 17:28:00 -0400
Subject: [PATCH] Update SparseGPT weight updates to respect the base model's sparsity

---
 src/sparseml/modifiers/obcq/base.py           |  1 +
 src/sparseml/modifiers/obcq/pytorch.py        |  1 +
 .../modifiers/obcq/utils/sgpt_wrapper.py      | 41 +++++++++++++++++--
 3 files changed, 40 insertions(+), 3 deletions(-)

diff --git a/src/sparseml/modifiers/obcq/base.py b/src/sparseml/modifiers/obcq/base.py
index f6e504e7b05..685694c8fae 100644
--- a/src/sparseml/modifiers/obcq/base.py
+++ b/src/sparseml/modifiers/obcq/base.py
@@ -56,6 +56,7 @@ class SparseGPTModifier(WandaPruningModifier):
     sparsity: Union[float, List[float]] = 0.0
     dampening_frac: Optional[float] = 0.01
     quantization_modifier_: Any = None
+    preserve_sparsity_mask: bool = False
 
     def on_initialize_structure(self, state: State, **kwargs):
         """
diff --git a/src/sparseml/modifiers/obcq/pytorch.py b/src/sparseml/modifiers/obcq/pytorch.py
index de1eef74189..b588dfe0d0a 100644
--- a/src/sparseml/modifiers/obcq/pytorch.py
+++ b/src/sparseml/modifiers/obcq/pytorch.py
@@ -91,6 +91,7 @@ def _pruning_arguments(self, sparsity):
             "prunem": self.prunem_,
             "blocksize": self.block_size,
             "percdamp": self.dampening_frac,
+            "preserve_sparsity_mask": self.preserve_sparsity_mask,
         }
 
     def _compression_class(self):
diff --git a/src/sparseml/modifiers/obcq/utils/sgpt_wrapper.py b/src/sparseml/modifiers/obcq/utils/sgpt_wrapper.py
index a911fa1c0c7..c2c5f36b544 100644
--- a/src/sparseml/modifiers/obcq/utils/sgpt_wrapper.py
+++ b/src/sparseml/modifiers/obcq/utils/sgpt_wrapper.py
@@ -84,6 +84,7 @@ def fasterprune(
         prunem: int = 0,
         blocksize: int = 128,
         percdamp: float = 0.01,
+        preserve_sparsity_mask: bool = False,
     ):
         """
         Run pruning and quantization(if applicable) on the layer up to the target
@@ -94,6 +95,7 @@ def fasterprune(
         :param prunem: M for N:M pruning
        :param blocksize: Number of columns to compress in one pass
         :param percdamp: Amount of dampening to apply to H, as a fraction of the
             diagonal norm
+        :param preserve_sparsity_mask: whether to extend the base sparsity mask
         """
         final_shape = self.layer.weight.shape
@@ -123,6 +125,13 @@ def fasterprune(
         Hinv = self.H
 
         mask = None
+        if preserve_sparsity_mask:
+            # compute existing sparsity mask
+            mask = torch.where(
+                W == 0,
+                torch.tensor(1, dtype=torch.bool),
+                torch.tensor(0, dtype=torch.bool),
+            )
 
         # See section 3.4 of https://arxiv.org/abs/2203.07259
         for i1 in range(0, self.columns, blocksize):
@@ -138,12 +147,32 @@ def fasterprune(
             if prunen == 0:
                 if mask is not None:
                     mask1 = mask[:, i1:i2]
+                    if int(W1.numel() * sparsity) > mask1.sum():
+                        # target sparsity is higher than base sparsity, extend mask1
+                        tmp = (
+                            (~mask[:, i1:i2])
+                            * W1**2
+                            / (torch.diag(Hinv1).reshape((1, -1))) ** 2
+                        )
+                        thresh = torch.sort(tmp.flatten())[0][
+                            int(tmp.numel() * sparsity)
+                        ]
+                        mask1 = tmp <= thresh
+                    else:
+                        raise ValueError(
+                            "The target sparsity is lower than the sparsity "
+                            "of the base model. Please retry "
+                            "after setting preserve_sparsity_mask=False"
+                        )
                 else:
                     tmp = W1**2 / (torch.diag(Hinv1).reshape((1, -1))) ** 2
                     thresh = torch.sort(tmp.flatten())[0][int(tmp.numel() * sparsity)]
                     mask1 = tmp <= thresh
             else:
-                mask1 = torch.zeros_like(W1) == 1
+                if mask is not None:
+                    mask1 = mask[:, i1:i2]
+                else:
+                    mask1 = torch.zeros_like(W1) == 1
 
             for i in range(count):
                 w = W1[:, i]
@@ -151,7 +180,8 @@ def fasterprune(
 
                 if prunen != 0 and i % prunem == 0:
                     tmp = (
-                        W1[:, i : (i + prunem)] ** 2
+                        (~mask[:, i : (i + prunem)] if mask is not None else 1)
+                        * W1[:, i : (i + prunem)] ** 2
                         / (torch.diag(Hinv1)[i : (i + prunem)].reshape((1, -1))) ** 2
                     )
                     mask1.scatter_(
@@ -182,7 +212,12 @@ def fasterprune(
                 W[:, i1:i2] = Q1
                 Losses += torch.sum(Losses1, 1) / 2
 
-                W[:, i2:] -= Err1.matmul(Hinv[i1:i2, i2:])
+                if preserve_sparsity_mask:
+                    # respect the sparsity of the other groups
+                    # not strictly needed, but kept for explicitness
+                    W[:, i2:] -= (~mask[:, i2:]) * Err1.matmul(Hinv[i1:i2, i2:])
+                else:
+                    W[:, i2:] -= Err1.matmul(Hinv[i1:i2, i2:])
 
         _LOGGER.info("time %.2f" % (time.time() - tick))
         _LOGGER.info("error %.2f" % torch.sum(Losses).item())
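
Usage note: below is a minimal, standalone sketch of the mask-extension step that the fasterprune hunks above implement, written against plain torch so it can be run in isolation. The helper name extend_sparsity_mask and the toy tensor shapes are illustrative only and are not part of sparseml's API; in the patch the same logic lives inside the `if int(W1.numel() * sparsity) > mask1.sum():` branch.

import torch


def extend_sparsity_mask(
    W1: torch.Tensor, hinv_diag: torch.Tensor, sparsity: float
) -> torch.Tensor:
    """Illustrative helper (not sparseml API): grow the mask implied by the
    already-zero weights in W1 until the block reaches the target sparsity."""
    base_mask = W1 == 0  # sparsity pattern inherited from the base model

    target_zeros = int(W1.numel() * sparsity)
    if target_zeros <= base_mask.sum():
        raise ValueError(
            "The target sparsity is lower than the sparsity of the base model. "
            "Please retry after setting preserve_sparsity_mask=False"
        )

    # SparseGPT-style saliency: W^2 / diag(Hinv)^2. Zeroing the scores of the
    # already-pruned weights makes them sort first, so the base mask is kept.
    scores = (~base_mask) * W1**2 / hinv_diag.reshape(1, -1) ** 2
    thresh = torch.sort(scores.flatten())[0][target_zeros]
    return scores <= thresh


# Toy example: a 4x8 block that is already 25% sparse, extended to ~50%
W1 = torch.randn(4, 8)
W1[:, :2] = 0.0
hinv_diag = torch.rand(8) + 0.5
mask1 = extend_sparsity_mask(W1, hinv_diag, sparsity=0.5)
assert bool(mask1[W1 == 0].all())  # every base zero stays masked
print(mask1.float().mean())  # roughly 0.5

Because weights that are already zero receive a saliency score of exactly zero, they always fall below the threshold and remain pruned, which is how the extended mask preserves the base model's sparsity pattern while adding just enough new zeros to reach the target.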