Skip to content

Commit baa9b58

Browse files
sayakpaul and DN6
authored
[core] parallel loading of shards (#12028)
* checking. * checking * checking * up * up * up * Apply suggestions from code review Co-authored-by: Dhruv Nair <[email protected]> * up * up * fix * review feedback. --------- Co-authored-by: Dhruv Nair <[email protected]>
1 parent da096a4 commit baa9b58

File tree

10 files changed

+251
-66
lines changed

10 files changed

+251
-66
lines changed

src/diffusers/loaders/single_file_model.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,7 @@
6262
if is_accelerate_available():
6363
from accelerate import dispatch_model, init_empty_weights
6464

65-
from ..models.modeling_utils import load_model_dict_into_meta
65+
from ..models.model_loading_utils import load_model_dict_into_meta
6666

6767
if is_torch_version(">=", "1.9.0") and is_accelerate_available():
6868
_LOW_CPU_MEM_USAGE_DEFAULT = True

src/diffusers/loaders/single_file_utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@
5555
if is_accelerate_available():
5656
from accelerate import init_empty_weights
5757

58-
from ..models.modeling_utils import load_model_dict_into_meta
58+
from ..models.model_loading_utils import load_model_dict_into_meta
5959

6060
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
6161

src/diffusers/loaders/transformer_flux.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,8 @@
1717
ImageProjection,
1818
MultiIPAdapterImageProjection,
1919
)
20-
from ..models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT, load_model_dict_into_meta
20+
from ..models.model_loading_utils import load_model_dict_into_meta
21+
from ..models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT
2122
from ..utils import is_accelerate_available, is_torch_version, logging
2223
from ..utils.torch_utils import empty_device_cache
2324

src/diffusers/loaders/transformer_sd3.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,8 @@
1616

1717
from ..models.attention_processor import SD3IPAdapterJointAttnProcessor2_0
1818
from ..models.embeddings import IPAdapterTimeImageProjection
19-
from ..models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT, load_model_dict_into_meta
19+
from ..models.model_loading_utils import load_model_dict_into_meta
20+
from ..models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT
2021
from ..utils import is_accelerate_available, is_torch_version, logging
2122
from ..utils.torch_utils import empty_device_cache
2223

src/diffusers/loaders/unet.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,8 @@
3030
IPAdapterPlusImageProjection,
3131
MultiIPAdapterImageProjection,
3232
)
33-
from ..models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT, load_model_dict_into_meta, load_state_dict
33+
from ..models.model_loading_utils import load_model_dict_into_meta
34+
from ..models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT, load_state_dict
3435
from ..utils import (
3536
USE_PEFT_BACKEND,
3637
_get_model_file,

src/diffusers/models/model_loading_utils.py

Lines changed: 158 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,12 +14,14 @@
1414
# See the License for the specific language governing permissions and
1515
# limitations under the License.
1616

17+
import functools
1718
import importlib
1819
import inspect
1920
import math
2021
import os
2122
from array import array
2223
from collections import OrderedDict, defaultdict
24+
from concurrent.futures import ThreadPoolExecutor, as_completed
2325
from pathlib import Path
2426
from typing import Dict, List, Optional, Union
2527
from zipfile import is_zipfile
@@ -31,6 +33,7 @@
3133

3234
from ..quantizers import DiffusersQuantizer
3335
from ..utils import (
36+
DEFAULT_HF_PARALLEL_LOADING_WORKERS,
3437
GGUF_FILE_EXTENSION,
3538
SAFE_WEIGHTS_INDEX_NAME,
3639
SAFETENSORS_FILE_EXTENSION,
@@ -310,6 +313,161 @@ def load_model_dict_into_meta(
310313
return offload_index, state_dict_index
311314

312315

316+
def check_support_param_buffer_assignment(model_to_load, state_dict, start_prefix=""):
    """
    Checks if `model_to_load` supports param buffer assignment (such as when loading in empty weights) by first
    checking if the model explicitly disables it, then by ensuring that the state dict keys are a subset of the
    model's parameters.

    Args:
        model_to_load: Model exposing `.device` and `.state_dict()` (e.g. a `ModelMixin` instance).
        state_dict (`dict`): Incoming checkpoint state dict; keys may carry `start_prefix`.
        start_prefix (`str`, *optional*, defaults to `""`): Prefix prepended to model keys inside `state_dict`.

    Returns:
        `bool`: Whether tensors can be assigned into the model instead of copied.
    """
    # Meta-device models hold no real storage; they must go through the dedicated meta-loading path.
    if model_to_load.device.type == "meta":
        return False

    # Nothing in the checkpoint targets this model under the given prefix.
    if not any(key.startswith(start_prefix) for key in state_dict):
        return False

    # Some models explicitly do not support param buffer assignment
    if not getattr(model_to_load, "_supports_param_buffer_assignment", True):
        logger.debug(
            f"{model_to_load.__class__.__name__} does not support param buffer assignment, loading will be slower"
        )
        return False

    # If the model does, the incoming `state_dict` and the `model_to_load` must be the same dtype.
    # Materialize the model state dict once; the original code rebuilt it on every access.
    model_state_dict = model_to_load.state_dict()
    first_key = next(iter(model_state_dict.keys()))
    if start_prefix + first_key in state_dict:
        return state_dict[start_prefix + first_key].dtype == model_state_dict[first_key].dtype

    return False
342+
343+
344+
def _load_shard_file(
    shard_file,
    model,
    model_state_dict,
    device_map=None,
    dtype=None,
    hf_quantizer=None,
    keep_in_fp32_modules=None,
    dduf_entries=None,
    loaded_keys=None,
    unexpected_keys=None,
    offload_index=None,
    offload_folder=None,
    state_dict_index=None,
    state_dict_folder=None,
    ignore_mismatched_sizes=False,
    low_cpu_mem_usage=False,
):
    """
    Load one checkpoint shard into `model`.

    Dispatches to the meta-device loading path when `low_cpu_mem_usage` is set and to a
    regular in-place state-dict load otherwise.

    Returns:
        `(offload_index, state_dict_index, mismatched_keys, error_msgs)` tuple.
    """
    # Materialize only this shard's tensors.
    shard_sd = load_state_dict(shard_file, dduf_entries=dduf_entries)
    # Record (and drop) checkpoint entries whose shapes disagree with the model's.
    mismatched_keys = _find_mismatched_keys(
        shard_sd,
        model_state_dict,
        loaded_keys,
        ignore_mismatched_sizes,
    )
    error_msgs = []
    if not low_cpu_mem_usage:
        # Regular path: assign tensors instead of copying them when the model allows it.
        assign_to_params_buffers = check_support_param_buffer_assignment(model, shard_sd)
        error_msgs += _load_state_dict_into_model(model, shard_sd, assign_to_params_buffers)
    else:
        # Meta-device path: fill empty weights, possibly offloading to disk/CPU.
        offload_index, state_dict_index = load_model_dict_into_meta(
            model,
            shard_sd,
            device_map=device_map,
            dtype=dtype,
            hf_quantizer=hf_quantizer,
            keep_in_fp32_modules=keep_in_fp32_modules,
            unexpected_keys=unexpected_keys,
            offload_folder=offload_folder,
            offload_index=offload_index,
            state_dict_index=state_dict_index,
            state_dict_folder=state_dict_folder,
        )
    return offload_index, state_dict_index, mismatched_keys, error_msgs
389+
390+
391+
def _load_shard_files_with_threadpool(
    shard_files,
    model,
    model_state_dict,
    device_map=None,
    dtype=None,
    hf_quantizer=None,
    keep_in_fp32_modules=None,
    dduf_entries=None,
    loaded_keys=None,
    unexpected_keys=None,
    offload_index=None,
    offload_folder=None,
    state_dict_index=None,
    state_dict_folder=None,
    ignore_mismatched_sizes=False,
    low_cpu_mem_usage=False,
):
    """
    Load multiple checkpoint shards concurrently using a thread pool.

    Each shard is loaded by `_load_shard_file`; mismatched keys and error messages from
    every shard are accumulated and returned together.

    Returns:
        `(offload_index, state_dict_index, mismatched_keys, error_msgs)` tuple, where the
        lists aggregate the results of all shards.
    """
    # Do not spawn any more workers than there are shards to load.
    num_workers = min(len(shard_files), DEFAULT_HF_PARALLEL_LOADING_WORKERS)

    logger.info(f"Loading model weights in parallel with {num_workers} workers...")

    error_msgs = []
    mismatched_keys = []

    # Bind every shared argument once; each worker only receives its own shard file.
    # NOTE(review): all workers receive the same `offload_index`/`state_dict_index`
    # objects — presumably they are mutated in place inside `load_model_dict_into_meta`,
    # since the rebinding below happens in arbitrary `as_completed` order; confirm.
    load_one = functools.partial(
        _load_shard_file,
        model=model,
        model_state_dict=model_state_dict,
        device_map=device_map,
        dtype=dtype,
        hf_quantizer=hf_quantizer,
        keep_in_fp32_modules=keep_in_fp32_modules,
        dduf_entries=dduf_entries,
        loaded_keys=loaded_keys,
        unexpected_keys=unexpected_keys,
        offload_index=offload_index,
        offload_folder=offload_folder,
        state_dict_index=state_dict_index,
        state_dict_folder=state_dict_folder,
        ignore_mismatched_sizes=ignore_mismatched_sizes,
        low_cpu_mem_usage=low_cpu_mem_usage,
    )

    with ThreadPoolExecutor(max_workers=num_workers) as executor:
        with logging.tqdm(total=len(shard_files), desc="Loading checkpoint shards") as pbar:
            futures = [executor.submit(load_one, shard_file) for shard_file in shard_files]
            # Consume results as they finish; `future.result()` re-raises any worker exception.
            for future in as_completed(futures):
                result = future.result()
                offload_index, state_dict_index, _mismatched_keys, _error_msgs = result
                error_msgs += _error_msgs
                mismatched_keys += _mismatched_keys
                pbar.update(1)

    return offload_index, state_dict_index, mismatched_keys, error_msgs
447+
448+
449+
def _find_mismatched_keys(
450+
state_dict,
451+
model_state_dict,
452+
loaded_keys,
453+
ignore_mismatched_sizes,
454+
):
455+
mismatched_keys = []
456+
if ignore_mismatched_sizes:
457+
for checkpoint_key in loaded_keys:
458+
model_key = checkpoint_key
459+
# If the checkpoint is sharded, we may not have the key here.
460+
if checkpoint_key not in state_dict:
461+
continue
462+
463+
if model_key in model_state_dict and state_dict[checkpoint_key].shape != model_state_dict[model_key].shape:
464+
mismatched_keys.append(
465+
(checkpoint_key, state_dict[checkpoint_key].shape, model_state_dict[model_key].shape)
466+
)
467+
del state_dict[checkpoint_key]
468+
return mismatched_keys
469+
470+
313471
def _load_state_dict_into_model(
314472
model_to_load, state_dict: OrderedDict, assign_to_params_buffers: bool = False
315473
) -> List[str]:

0 commit comments

Comments
 (0)