class DefaultModelLoader(BaseModelLoader):
"""Model loader that can load different file types from disk."""
    # Default number of threads to use when multithreaded weight loading is enabled.
DEFAULT_NUM_THREADS = 8
@dataclasses.dataclass
class Source:
"""A source for weights."""
model_or_path: str
"""The model ID or path."""
revision: str | None
"""The optional model revision."""
prefix: str = ""
"""A prefix to prepend to all weights."""
fall_back_to_pt: bool = True
"""Whether .pt weights can be used."""
allow_patterns_overrides: list[str] | None = None
"""If defined, weights will load exclusively using these patterns."""
counter_before_loading_weights: float = 0.0
counter_after_loading_weights: float = 0.0
def __init__(self, load_config: LoadConfig):
super().__init__(load_config)
extra_config = load_config.model_loader_extra_config
allowed_keys = {"enable_multithread_load", "num_threads"}
unexpected_keys = set(extra_config.keys()) - allowed_keys
if unexpected_keys:
raise ValueError(
f"Unexpected extra config keys for load format "
f"{load_config.load_format}: "
f"{unexpected_keys}"
)
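        # Illustrative extra config accepted by this loader (values are
        # examples only; LoadConfig's full signature is defined elsewhere):
        #
        #     LoadConfig(
        #         load_format="safetensors",
        #         model_loader_extra_config={
        #             "enable_multithread_load": True,
        #             "num_threads": 16,
        #         },
        #     )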
def _prepare_weights(
self,
model_name_or_path: str,
revision: str | None,
fall_back_to_pt: bool,
allow_patterns_overrides: list[str] | None,
) -> tuple[str, list[str], bool]:
"""Prepare weights for the model.
If the model is not local, it will be downloaded."""
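        # Illustrative return value for a local safetensors checkpoint
        # (hypothetical path and file names):
        #
        #     ("/models/llama-3-8b",
        #      ["/models/llama-3-8b/model-00001-of-00002.safetensors",
        #       "/models/llama-3-8b/model-00002-of-00002.safetensors"],
        #      True)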
model_name_or_path = (
maybe_download_from_modelscope(model_name_or_path, revision)
or model_name_or_path
)
is_local = os.path.isdir(model_name_or_path)
load_format = self.load_config.load_format
use_safetensors = False
index_file = SAFE_WEIGHTS_INDEX_NAME
        # For the 'auto' load format, first check whether Mistral-format files
        # are present, so that Mistral models are loaded with their official
        # format by default.
if load_format == "auto":
load_format = (
"mistral"
if len(
list_filtered_repo_files(
model_name_or_path=model_name_or_path,
allow_patterns=["consolidated*.safetensors"],
revision=revision,
)
)
> 0
else "hf"
)
# Some quantized models use .pt files for storing the weights.
if load_format == "hf":
allow_patterns = ["*.safetensors", "*.bin"]
elif load_format == "safetensors" or load_format == "fastsafetensors":
use_safetensors = True
allow_patterns = ["*.safetensors"]
elif load_format == "mistral":
use_safetensors = True
allow_patterns = ["consolidated*.safetensors"]
index_file = "consolidated.safetensors.index.json"
elif load_format == "pt":
allow_patterns = ["*.pt"]
elif load_format == "npcache":
allow_patterns = ["*.bin"]
else:
raise ValueError(f"Unknown load_format: {load_format}")
if fall_back_to_pt:
allow_patterns += ["*.pt"]
if allow_patterns_overrides is not None:
allow_patterns = allow_patterns_overrides
if not is_local:
hf_folder = download_weights_from_hf(
model_name_or_path,
self.load_config.download_dir,
allow_patterns,
revision,
ignore_patterns=self.load_config.ignore_patterns,
)
else:
hf_folder = model_name_or_path
hf_weights_files: list[str] = []
for pattern in allow_patterns:
hf_weights_files += glob.glob(os.path.join(hf_folder, pattern))
if len(hf_weights_files) > 0:
if pattern == "*.safetensors":
use_safetensors = True
break
if use_safetensors:
            # For models like Mistral-7B-Instruct-v0.3, the repository contains
            # both sharded safetensors files and a consolidated safetensors
            # file; loading both would break. Here, we download the
            # `model.safetensors.index.json` and filter out any files not
            # found in the index.
if not is_local:
download_safetensors_index_file_from_hf(
model_name_or_path,
index_file,
self.load_config.download_dir,
revision,
)
hf_weights_files = filter_duplicate_safetensors_files(
hf_weights_files, hf_folder, index_file
)
else:
hf_weights_files = filter_files_not_needed_for_inference(hf_weights_files)
if len(hf_weights_files) == 0:
raise RuntimeError(
f"Cannot find any model weights with `{model_name_or_path}`"
)
return hf_folder, hf_weights_files, use_safetensors
def _get_weights_iterator(
self, source: "Source"
) -> Generator[tuple[str, torch.Tensor], None, None]:
"""Get an iterator for the model weights based on the load format."""
extra_config = self.load_config.model_loader_extra_config
hf_folder, hf_weights_files, use_safetensors = self._prepare_weights(
source.model_or_path,
source.revision,
source.fall_back_to_pt,
source.allow_patterns_overrides,
)
if self.load_config.load_format == "npcache":
            # Currently, np_cache only supports *.bin checkpoints.
assert use_safetensors is False
weights_iterator = np_cache_weights_iterator(
source.model_or_path,
self.load_config.download_dir,
hf_folder,
hf_weights_files,
self.load_config.use_tqdm_on_load,
)
elif use_safetensors:
if self.load_config.load_format == "fastsafetensors":
weights_iterator = fastsafetensors_weights_iterator(
hf_weights_files,
self.load_config.use_tqdm_on_load,
)
else:
if extra_config.get("enable_multithread_load"):
weights_iterator = multi_thread_safetensors_weights_iterator(
hf_weights_files,
self.load_config.use_tqdm_on_load,
max_workers=extra_config.get(
"num_threads", self.DEFAULT_NUM_THREADS
),
)
else:
weights_iterator = safetensors_weights_iterator(
hf_weights_files,
self.load_config.use_tqdm_on_load,
self.load_config.safetensors_load_strategy,
)
else:
if extra_config.get("enable_multithread_load"):
weights_iterator = multi_thread_pt_weights_iterator(
hf_weights_files,
self.load_config.use_tqdm_on_load,
self.load_config.pt_load_map_location,
max_workers=extra_config.get(
"num_threads", self.DEFAULT_NUM_THREADS
),
)
else:
weights_iterator = pt_weights_iterator(
hf_weights_files,
self.load_config.use_tqdm_on_load,
self.load_config.pt_load_map_location,
)
if self.counter_before_loading_weights == 0.0:
self.counter_before_loading_weights = time.perf_counter()
# Apply the prefix.
return ((source.prefix + name, tensor) for (name, tensor) in weights_iterator)
def get_all_weights(
self,
model_config: ModelConfig,
model: nn.Module,
) -> Generator[tuple[str, torch.Tensor], None, None]:
primary_weights = DefaultModelLoader.Source(
model_config.model,
model_config.revision,
prefix="",
fall_back_to_pt=getattr(model, "fall_back_to_pt_during_load", True),
allow_patterns_overrides=getattr(model, "allow_patterns_overrides", None),
)
yield from self._get_weights_iterator(primary_weights)
secondary_weights = cast(
Iterable[DefaultModelLoader.Source],
getattr(model, "secondary_weights", ()),
)
for source in secondary_weights:
yield from self._get_weights_iterator(source)
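        # Models may expose additional weight sources through a
        # `secondary_weights` attribute, e.g. (illustrative only; the repo id
        # and prefix are hypothetical):
        #
        #     model.secondary_weights = [
        #         DefaultModelLoader.Source(
        #             "some-org/some-extra-weights",
        #             revision=None,
        #             prefix="extra.",
        #         )
        #     ]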
def download_model(self, model_config: ModelConfig) -> None:
self._prepare_weights(
model_config.model,
model_config.revision,
fall_back_to_pt=True,
allow_patterns_overrides=None,
)
def load_weights(self, model: nn.Module, model_config: ModelConfig) -> None:
if model_config.quantization == "torchao":
quant_config = get_quant_config(model_config, self.load_config)
if (
hasattr(quant_config, "is_checkpoint_torchao_serialized")
and quant_config.is_checkpoint_torchao_serialized
and torchao_version_at_least("0.15.0")
):
self.load_config.safetensors_load_strategy = "torchao"
weights_to_load = {name for name, _ in model.named_parameters()}
loaded_weights = model.load_weights(self.get_all_weights(model_config, model))
self.counter_after_loading_weights = time.perf_counter()
logger.info_once(
"Loading weights took %.2f seconds",
self.counter_after_loading_weights - self.counter_before_loading_weights,
scope="local",
)
        # Currently, we only enable the strict check for non-quantized models
        # that track their loaded weights.
if model_config.quantization is None and loaded_weights is not None:
weights_not_loaded = weights_to_load - loaded_weights
if weights_not_loaded:
raise ValueError(
"Following weights were not initialized from "
f"checkpoint: {weights_not_loaded}"
)