From a00043e430cb3b50c52fee2bcdc6b4b74f378f4e Mon Sep 17 00:00:00 2001 From: Michael Goin Date: Fri, 20 Sep 2024 20:33:35 -0400 Subject: [PATCH] Write the current version into model configs --- src/compressed_tensors/__init__.py | 1 + src/compressed_tensors/compressors/model_compressor.py | 2 ++ 2 files changed, 3 insertions(+) diff --git a/src/compressed_tensors/__init__.py b/src/compressed_tensors/__init__.py index 0833dd42..64a52dda 100644 --- a/src/compressed_tensors/__init__.py +++ b/src/compressed_tensors/__init__.py @@ -19,3 +19,4 @@ from .config import * from .quantization import QuantizationConfig, QuantizationStatus from .utils import * +from .version import * diff --git a/src/compressed_tensors/compressors/model_compressor.py b/src/compressed_tensors/compressors/model_compressor.py index ea0ee43c..28d1a7c3 100644 --- a/src/compressed_tensors/compressors/model_compressor.py +++ b/src/compressed_tensors/compressors/model_compressor.py @@ -22,6 +22,7 @@ import torch import transformers +import compressed_tensors from compressed_tensors.base import ( COMPRESSION_CONFIG_NAME, QUANTIZATION_CONFIG_NAME, @@ -368,6 +369,7 @@ def update_config(self, save_directory: str): config_data[COMPRESSION_CONFIG_NAME][ SPARSITY_CONFIG_NAME ] = sparsity_config_data + config_data[COMPRESSION_CONFIG_NAME]["version"] = compressed_tensors.__version__ with open(config_file_path, "w") as config_file: json.dump(config_data, config_file, indent=2, sort_keys=True)