
Lora training issue #43

Description

@Lu-Gru

Hello,
Everything seems to be working correctly (running gradio / cli / tests all pass), but when I attempt to train a LoRA I encounter an issue; output below:

(venv) Y:\AI_ToolKit\SA3\stable-audio-3>python scripts/train_lora.py     --model medium-base     --data_dir Y:\AI_ToolKit\SA3\stable-audio-3/../my_data     --rank 16     --adapter_type dora-rows      --save_dir Y:\AI_ToolKit\SA3\stable-audio-3/../out_lora         --checkpoint_every 500  --demo_every 500        --logger "csv"  --name "tekLora"        --num_workers 1         --batch_size 1  --lr 0.0001     --steps 1000
Seed set to 42
Y:\AI_ToolKit\SA3\venv\lib\site-packages\torch\nn\utils\weight_norm.py:143: FutureWarning: `torch.nn.utils.weight_norm` is deprecated in favor of `torch.nn.utils.parametrizations.weight_norm`.
  WeightNorm.apply(module, name, dim)
Found 5 files
LoRA config: rank=16, alpha=16, adapter_type=dora-rows
lora layers: 229
Demo sample 0: prompt=techno, fast seconds_total=16
Demo sample 1: prompt=techno, fast seconds_total=16
Demo sample 2: prompt=techno, fast seconds_total=16
Demo sample 3: prompt=techno, fast seconds_total=16
Using bfloat16 Automatic Mixed Precision (AMP)
Trainer already configured with model summary callbacks: [<class 'pytorch_lightning.callbacks.model_summary.ModelSummary'>]. Skipping setting a default `ModelSummary` callback.
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
💡 Tip: For seamless cloud logging and experiment tracking, try installing [litlogger](https://pypi.org/project/litlogger/) to enable LitLogger, which logs metrics and artifacts automatically to the Lightning Experiments platform.
Y:\AI_ToolKit\SA3\venv\lib\site-packages\pytorch_lightning\trainer\configuration_validator.py:70: You defined a `validation_step` but have no `val_dataloader`. Skipping val loop.
Y:\AI_ToolKit\SA3\venv\lib\site-packages\pytorch_lightning\callbacks\model_checkpoint.py:881: Checkpoint directory Y:\AI_ToolKit\SA3\out_lora exists and is not empty.
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
Y:\AI_ToolKit\SA3\venv\lib\site-packages\pytorch_lightning\utilities\model_summary\model_summary.py:242: Precision bf16-mixed is not supported by the model summary.  Estimated model size in MB will not be accurate. Using 32 bits instead.

  | Name                   | Type                             | Params | Mode | FLOPs
-------------------------------------------------------------------------------------------
0 | diffusion              | ConditionedDiffusionModelWrapper | 2.3 B  | eval | 0
1 | diffusion.model        | DiTWrapper                       | 1.5 B  | eval | 0
2 | diffusion.conditioner  | MultiConditioner                 | 215 K  | eval | 0
3 | diffusion.pretransform | AutoencoderPretransform          | 852 M  | eval | 0
-------------------------------------------------------------------------------------------
21.6 M    Trainable params
2.3 B     Non-trainable params
2.3 B     Total params
9,308.321 Total estimated model params size (MB)
458       Modules in train mode
1510      Modules in eval mode
0         Total Flops
Y:\AI_ToolKit\SA3\venv\lib\site-packages\pytorch_lightning\trainer\connectors\data_connector.py:429: Consider setting `persistent_workers=True` in 'train_dataloader' to speed up the dataloader worker initialization.
AttributeError: Can't pickle local object 'train.<locals>.<lambda>'
Traceback (most recent call last):
  File "Y:\AI_ToolKit\SA3\venv\lib\site-packages\pytorch_lightning\trainer\call.py", line 49, in _call_and_handle_interrupt
    return trainer_fn(*args, **kwargs)
  File "Y:\AI_ToolKit\SA3\venv\lib\site-packages\pytorch_lightning\trainer\trainer.py", line 630, in _fit_impl
    self._run(model, ckpt_path=ckpt_path, weights_only=weights_only)
  File "Y:\AI_ToolKit\SA3\venv\lib\site-packages\pytorch_lightning\trainer\trainer.py", line 1079, in _run
    results = self._run_stage()
  File "Y:\AI_ToolKit\SA3\venv\lib\site-packages\pytorch_lightning\trainer\trainer.py", line 1123, in _run_stage
    self.fit_loop.run()
  File "Y:\AI_ToolKit\SA3\venv\lib\site-packages\pytorch_lightning\loops\fit_loop.py", line 209, in run
    self.setup_data()
  File "Y:\AI_ToolKit\SA3\venv\lib\site-packages\pytorch_lightning\loops\fit_loop.py", line 276, in setup_data
    iter(self._data_fetcher)  # creates the iterator inside the fetcher
  File "Y:\AI_ToolKit\SA3\venv\lib\site-packages\pytorch_lightning\loops\fetchers.py", line 105, in __iter__
    super().__iter__()
  File "Y:\AI_ToolKit\SA3\venv\lib\site-packages\pytorch_lightning\loops\fetchers.py", line 52, in __iter__
    self.iterator = iter(self.combined_loader)
  File "Y:\AI_ToolKit\SA3\venv\lib\site-packages\pytorch_lightning\utilities\combined_loader.py", line 351, in __iter__
    iter(iterator)
  File "Y:\AI_ToolKit\SA3\venv\lib\site-packages\pytorch_lightning\utilities\combined_loader.py", line 92, in __iter__
    super().__iter__()
  File "Y:\AI_ToolKit\SA3\venv\lib\site-packages\pytorch_lightning\utilities\combined_loader.py", line 43, in __iter__
    self.iterators = [iter(iterable) for iterable in self.iterables]
  File "Y:\AI_ToolKit\SA3\venv\lib\site-packages\pytorch_lightning\utilities\combined_loader.py", line 43, in <listcomp>
    self.iterators = [iter(iterable) for iterable in self.iterables]
  File "Y:\AI_ToolKit\SA3\venv\lib\site-packages\torch\utils\data\dataloader.py", line 493, in __iter__
    return self._get_iterator()
  File "Y:\AI_ToolKit\SA3\venv\lib\site-packages\torch\utils\data\dataloader.py", line 424, in _get_iterator
    return _MultiProcessingDataLoaderIter(self)
  File "Y:\AI_ToolKit\SA3\venv\lib\site-packages\torch\utils\data\dataloader.py", line 1171, in __init__
    w.start()
  File "multiprocessing\process.py", line 121, in start
  File "multiprocessing\context.py", line 224, in _Popen
  File "multiprocessing\context.py", line 336, in _Popen
  File "multiprocessing\popen_spawn_win32.py", line 93, in __init__
  File "multiprocessing\reduction.py", line 60, in dump
AttributeError: Can't pickle local object 'train.<locals>.<lambda>'

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "Y:\AI_ToolKit\SA3\stable-audio-3\scripts\train_lora.py", line 391, in <module>
    main()
  File "Y:\AI_ToolKit\SA3\stable-audio-3\scripts\train_lora.py", line 387, in main
    train(args)
  File "Y:\AI_ToolKit\SA3\stable-audio-3\scripts\train_lora.py", line 288, in train
    trainer.fit(training_wrapper, dataloader)
  File "Y:\AI_ToolKit\SA3\venv\lib\site-packages\pytorch_lightning\trainer\trainer.py", line 584, in fit
    call._call_and_handle_interrupt(
  File "Y:\AI_ToolKit\SA3\venv\lib\site-packages\pytorch_lightning\trainer\call.py", line 70, in _call_and_handle_interrupt
    trainer._teardown()
  File "Y:\AI_ToolKit\SA3\venv\lib\site-packages\pytorch_lightning\trainer\trainer.py", line 1106, in _teardown
    loop.teardown()
  File "Y:\AI_ToolKit\SA3\venv\lib\site-packages\pytorch_lightning\loops\fit_loop.py", line 516, in teardown
    self._data_fetcher.teardown()
  File "Y:\AI_ToolKit\SA3\venv\lib\site-packages\pytorch_lightning\loops\fetchers.py", line 80, in teardown
    self.reset()
  File "Y:\AI_ToolKit\SA3\venv\lib\site-packages\pytorch_lightning\loops\fetchers.py", line 142, in reset
    super().reset()
  File "Y:\AI_ToolKit\SA3\venv\lib\site-packages\pytorch_lightning\loops\fetchers.py", line 76, in reset
    self.length = sized_len(self.combined_loader)
  File "Y:\AI_ToolKit\SA3\venv\lib\site-packages\lightning_fabric\utilities\data.py", line 52, in sized_len
    length = len(dataloader)  # type: ignore [arg-type]
  File "Y:\AI_ToolKit\SA3\venv\lib\site-packages\pytorch_lightning\utilities\combined_loader.py", line 358, in __len__
    raise RuntimeError("Please call `iter(combined_loader)` first.")
RuntimeError: Please call `iter(combined_loader)` first.

(venv) Y:\AI_ToolKit\SA3\stable-audio-3>cmd
Microsoft Windows [Version 10.0.19045.6216]
(c) Microsoft Corporation. All rights reserved.

(venv) Y:\AI_ToolKit\SA3\stable-audio-3>Traceback (most recent call last):
  File "<string>", line 1, in <module>
  File "multiprocessing\spawn.py", line 116, in spawn_main
  File "multiprocessing\spawn.py", line 126, in _main
EOFError: Ran out of input
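
For context, here is a minimal sketch of what seems to be happening. It assumes the lambda named in the traceback is handed to the `DataLoader` (for example as a `collate_fn`); the real body of `train()` is not shown here, and the names below are hypothetical. On Windows, `multiprocessing` starts workers with the `spawn` method (note `popen_spawn_win32` in the traceback), which pickles what each worker needs. A lambda defined inside a function cannot be pickled. The `EOFError: Ran out of input` in the child process is consistent with the parent dying before it finished sending that pickled data.

```python
import pickle


def train():
    # Hypothetical stand-in for the real train(): it builds a lambda locally,
    # e.g. to hand to a DataLoader as collate_fn.
    collate = lambda batch: batch
    return collate


def can_pickle(obj):
    """Return True if obj survives pickle.dumps, which spawn-started workers need."""
    try:
        pickle.dumps(obj)
        return True
    except (AttributeError, pickle.PicklingError) as exc:
        # An AttributeError (as in the traceback above) or, on some Python
        # versions, a PicklingError mentioning 'train.<locals>.<lambda>'.
        print(f"{type(exc).__name__}: {exc}")
        return False


if __name__ == "__main__":
    print(can_pickle(train()))  # the local lambda cannot be pickled
    print(can_pickle(len))      # a builtin is pickled by reference
```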

To train, I use the following command on Windows:

call %CD%\venv\Scripts\activate.bat
cd stable-audio-3
python scripts/train_lora.py ^
    --model medium-base ^
    --data_dir %CD%/../my_data ^
    --rank 16 ^
    --adapter_type dora-rows ^
    --save_dir %CD%/../out_lora ^
    --checkpoint_every 500 ^
    --demo_every 500 ^
    --logger "csv" ^
    --name "tekLora" ^
    --num_workers 1 ^
    --batch_size 1 ^
    --lr 0.0001 ^
    --steps 1000

cmd
pause
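
If the lambda in `train()` is only used for data loading, two workarounds seem plausible. This is a sketch under assumptions, not a confirmed fix for `train_lora.py`. The first is to run with `--num_workers 0`, which makes the `DataLoader` load batches in the main process so nothing has to be pickled. The second is to replace the local lambda with a module-level function, optionally bound with `functools.partial`, which pickles by reference. The helper names below are hypothetical and only illustrate the pattern.

```python
import functools
import pickle


def scale_batch(batch, factor):
    # Hypothetical module-level helper: spawn workers can re-import it by name.
    return [x * factor for x in batch]


def make_collate_local(factor):
    # Local lambda closing over factor: NOT picklable, so it fails under spawn.
    return lambda batch: [x * factor for x in batch]


def make_collate_picklable(factor):
    # Same behaviour, but a partial of a module-level function pickles cleanly.
    return functools.partial(scale_batch, factor=factor)


if __name__ == "__main__":
    collate = make_collate_picklable(2)
    restored = pickle.loads(pickle.dumps(collate))
    print(restored([1, 2, 3]))  # [2, 4, 6]
```

Both factories return callables with identical behaviour. Only the `partial` version can cross the process boundary that `spawn` introduces.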

Any advice is more than welcome :)
Thanks in advance.

EDIT: I noticed that the file count and the prompt output don't match. In the example above there are 5 files in the data directory, but only 4 prompts are listed:

Found 5 files
LoRA config: rank=16, alpha=16, adapter_type=dora-rows
lora layers: 229
Demo sample 0: prompt=techno, fast seconds_total=16
Demo sample 1: prompt=techno, fast seconds_total=16
Demo sample 2: prompt=techno, fast seconds_total=16
Demo sample 3: prompt=techno, fast seconds_total=16

When I delete one file (along with its associated .txt), the output matches:

Found 4 files
LoRA config: rank=16, alpha=16, adapter_type=dora-rows
lora layers: 229
Demo sample 0: prompt=techno, fast seconds_total=16
Demo sample 1: prompt=techno, fast seconds_total=16
Demo sample 2: prompt=techno, fast seconds_total=16
Demo sample 3: prompt=techno, fast seconds_total=16
