
Error when loading FLUX model from an existing checkpoint with multiprocess #203


Description

@hx89

We hit the following error when running FLUX training with one process per GPU and `train_new_flux=False`, which loads the FLUX model from an existing checkpoint:

File "/opt/maxdiffusion/src/maxdiffusion/trainers/flux_trainer.py", line 112, in start_training
flux_state, flux_state_mesh_shardings, flux_learning_rate_scheduler = self.create_flux_state(
^^^^^^^^^^^^^^^^^^^^^^^
File "/opt/maxdiffusion/src/maxdiffusion/checkpointing/flux_checkpointer.py", line 111, in create_flux_state
flux_state = jax.device_put(flux_state, state_mesh_shardings)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/opt/jax/jax/_src/api.py", line 2376, in device_put
out_flat = dispatch.device_put_p.bind(
^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/opt/jax/jax/_src/core.py", line 502, in bind
return self._true_bind(*args, **params)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/opt/jax/jax/_src/core.py", line 520, in _true_bind
return self.bind_with_trace(prev_trace, args, params)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/opt/jax/jax/_src/core.py", line 525, in bind_with_trace
return trace.process_primitive(self, args, params)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/opt/jax/jax/_src/core.py", line 1029, in process_primitive
return primitive.impl(*args, **params)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/opt/jax/jax/_src/dispatch.py", line 553, in _batched_device_put_impl
y = _device_put_impl(x, device=device, src=src, copy=cp)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/opt/jax/jax/_src/dispatch.py", line 542, in _device_put_impl
return _device_put_sharding_impl(x, aval, device, copy)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/opt/jax/jax/_src/dispatch.py", line 479, in _device_put_sharding_impl
raise ValueError(
ValueError: device_put's second argument must be a Device or a Sharding which represents addressable devices, but got NamedSharding(mesh=Mesh('data': 1, 'fsdp': 8, 'tensor': 1, axis_types=(Auto, Auto, Auto)), spec=PartitionSpec('tensor',), memory_kind=device). Please pass device or Sharding which represents addressable devices.

The error goes away if we set `train_new_flux=True`.
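
For context on what the traceback is saying: `jax.device_put` can only place data onto devices that the calling process can address, while `state_mesh_shardings` here is a `NamedSharding` over the global ('data': 1, 'fsdp': 8, 'tensor': 1) mesh, which in a one-process-per-GPU run includes devices belonging to the other processes. Below is a minimal sketch of that failure mode, plus one possible workaround via `jax.make_array_from_process_local_data`; this is an assumption on our side, not necessarily the right fix for flux_checkpointer.py:

```python
# Minimal sketch, not the maxdiffusion code: in a one-process-per-GPU run,
# jax.devices() returns the *global* device list, but jax.device_put can
# only target devices addressable from the current process.
import jax
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec

# Hypothetical mesh matching the shape in the error message.
devices = np.array(jax.devices()).reshape(1, -1, 1)
mesh = Mesh(devices, ("data", "fsdp", "tensor"))
sharding = NamedSharding(mesh, PartitionSpec("tensor"))

x = np.zeros((16, 128), dtype=np.float32)  # stand-in for a checkpoint tensor

if sharding.is_fully_addressable:
    # Single-process case: device_put with a global NamedSharding is fine.
    gx = jax.device_put(x, sharding)
else:
    # Multi-process case: jax.device_put(x, sharding) raises the ValueError
    # above. One possible workaround (an assumption, not the project's fix)
    # is to assemble the global array from each process's local shard:
    gx = jax.make_array_from_process_local_data(sharding, x)
```

This would also be consistent with `train_new_flux=True` working: a freshly initialized state presumably comes out of a jitted init function, where global shardings are allowed, so the failing `device_put` path is never taken.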
