Description
We get the following error when running FLUX training with one process per GPU and train_new_flux=False, i.e. when the FLUX model is loaded from an existing checkpoint:
  File "/opt/maxdiffusion/src/maxdiffusion/trainers/flux_trainer.py", line 112, in start_training
    flux_state, flux_state_mesh_shardings, flux_learning_rate_scheduler = self.create_flux_state(
                                                                          ^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/maxdiffusion/src/maxdiffusion/checkpointing/flux_checkpointer.py", line 111, in create_flux_state
    flux_state = jax.device_put(flux_state, state_mesh_shardings)
                 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/jax/jax/_src/api.py", line 2376, in device_put
    out_flat = dispatch.device_put_p.bind(
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/jax/jax/_src/core.py", line 502, in bind
    return self._true_bind(*args, **params)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/jax/jax/_src/core.py", line 520, in _true_bind
    return self.bind_with_trace(prev_trace, args, params)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/jax/jax/_src/core.py", line 525, in bind_with_trace
    return trace.process_primitive(self, args, params)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/jax/jax/_src/core.py", line 1029, in process_primitive
    return primitive.impl(*args, **params)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/jax/jax/_src/dispatch.py", line 553, in _batched_device_put_impl
    y = _device_put_impl(x, device=device, src=src, copy=cp)
        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/jax/jax/_src/dispatch.py", line 542, in _device_put_impl
    return _device_put_sharding_impl(x, aval, device, copy)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/jax/jax/_src/dispatch.py", line 479, in _device_put_sharding_impl
    raise ValueError(
ValueError: device_put's second argument must be a Device or a Sharding which represents addressable devices, but got NamedSharding(mesh=Mesh('data': 1, 'fsdp': 8, 'tensor': 1, axis_types=(Auto, Auto, Auto)), spec=PartitionSpec('tensor',), memory_kind=device). Please pass device or Sharding which represents addressable devices.
The error does not occur with train_new_flux=True.
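The message suggests that jax.device_put is being handed a NamedSharding whose mesh includes devices owned by other processes: with one process per GPU, the 8-device fsdp mesh is global, but each process can only address its own GPU. One possible workaround, sketched below under the assumption that every process holds the full checkpointed state as host (NumPy) arrays, is to build each global array with jax.make_array_from_callback, which asks each process only for the shards that live on its local devices. place_tree is a hypothetical helper, not part of maxdiffusion, and this sketch has not been tested inside create_flux_state.

```python
import jax
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P


def place_tree(host_tree, sharding_tree):
    """Hypothetical helper: place a pytree of host arrays onto shardings.

    Unlike jax.device_put in the traceback above, jax.make_array_from_callback
    only requests the shards for this process's addressable devices, so the
    sharding does not need to be fully addressable.
    Assumes every process holds the complete host array for each leaf.
    """
    def place(x, sharding):
        x = np.asarray(x)
        # The callback receives an index (a tuple of slices) for one local shard.
        return jax.make_array_from_callback(x.shape, sharding, lambda idx: x[idx])

    return jax.tree_util.tree_map(place, host_tree, sharding_tree)


# Single-process demo: a 1-D "fsdp" mesh over all visible devices.
devices = np.array(jax.devices())
mesh = Mesh(devices, ("fsdp",))
sharding = NamedSharding(mesh, P("fsdp"))

# Hypothetical stand-in for a checkpointed parameter tree.
params = {"w": np.arange(8 * len(devices), dtype=np.float32)}
shardings = {"w": sharding}

# True here; in a one-process-per-GPU run over 8 GPUs this would be False,
# which is the situation jax.device_put rejects.
print(sharding.is_fully_addressable)

placed = place_tree(params, shardings)
print(placed["w"].sharding == sharding)
```

In a single-process run both checks hold trivially; the point of the callback form is that it keeps working when is_fully_addressable is False, because each process only materializes its own shards.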