Description
When testing FLUX GPU training with FSDP=8 using the MaxDiffusion main branch and the same run script that worked before, we hit an OOM error.
The log contains multiple error messages indicating that XLA is trying to allocate several large buffers for collectives, for example:

```
allocator.cc:62 NCCL WARN Cuda failure 2 'out of memory'
external/xla/xla/stream_executor/integrations/stream_executor_allocator.cc:66] could not allocate collective of size: 20289079296
```

That failed collective allocation is roughly 18.9 GiB (20289079296 bytes).
Our run command is:

```
python3 src/maxdiffusion/train_flux.py src/maxdiffusion/configs/base_flux_schnell.yml hardware=gpu run_name=flux attention=cudnn_flash_te max_train_steps=20 enable_profiler=True profiler_steps=5 skip_first_n_steps_for_profiler=10 profiler=xplane train_new_flux=True
```
I have tried various XLA_PYTHON_CLIENT_MEM_FRACTION values, but that didn't help.
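For reference, the fraction was varied by setting the environment variable before launching; the specific values below are illustrative, not the exact ones tried:

```
# Illustrative only: lower the fraction of GPU memory JAX preallocates,
# leaving more headroom for NCCL collective buffers. Exact values varied.
XLA_PYTHON_CLIENT_MEM_FRACTION=0.65 python3 src/maxdiffusion/train_flux.py \
  src/maxdiffusion/configs/base_flux_schnell.yml \
  hardware=gpu run_name=flux attention=cudnn_flash_te \
  max_train_steps=20 train_new_flux=True
```

Lowering the fraction did free some of the memory JAX preallocates for its own arena, but the large collective allocations shown in the log above still failed.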