LLM/LVM์ SFT, DPO, Pretrain์ผ๋ก ํ์ตํ๊ณ ์ ์ฒ๋ฆฌํ๋ ์ฝ๋๋ฅผ ์ ๋ฆฌํ repo. ์์ฃผ ์ฌ์ฉํ๋ ํ์ต ๋ฐฉ์์ ์ฝ๋๋ก ๊ตณํ๋ ๊ฒ์ด๊ณ , ๋งค๋ฒ ๋ง๋ค๋ 1ํ์ฉ ์ฝ๋๋ฅผ ์ต๋ํ ์ค์ด๊ณ GPU๋ฅผ ๋ญ๋น ์์ด ํจ์จ์ ์ผ๋ก ์ฌ์ฉํ๋๋ฐ ์ด์ ์ด ๋ง์ถฐ์ ธ ์๋ค.
- LLM ํ์ต์ ๋๋ฌด ๋นํจ์จ์ ์ด๋ค. GPU ์์์ ์ต๋๋ก ์ฐ๋ฉด์ ์ฒ๋ฆฌ๋์ ์ฌ๋ฆฌ๋ ๋ฐฉ๋ฒ๋ค์ ๊ตฌํ.
- ๋ด๊ฐ ์์ฃผ ํ๋ ํ์ต ์คํ์ผ์ ์ฝ๋๋ก ์ ๋ฆฌํด ๊ตณํ๋๋ค.
- 1ํ์ฉ ์์ ์ ์ต๋ํ ์ค์ธ๋ค.
{method}/ ์๋๋ ํญ์ ์ด๋ ๊ฒ ๋๋ค.
main.py(ํ์): args +def main+if "__main__" in __name__preprocessor.py: ๋ฐ์ดํฐ ์ ์ฒ๋ฆฌ- ๊ทธ ์ธ(callback ๋ฑ)๋
callbacks.py๊ฐ์ ํ์ผ๋ก๋ง ๋ถ๋ฆฌ
์์ ์
args.py,collator.py๋ก ์ชผ๊ฐฐ๋๋ ํ์ผ๊ณผ VSCode ์ฐฝ๋ง ๋๊ณ ํ๋์ ์ ๋ค์ด์๋ค. Python ๋ชจ๋ ์ฒ ํ๊ณผ๋ ์ด๊ธ๋์ง๋งmain.pyํตํฉ์ ํํ๋ค. ํ์ผ ์๋ฅผ ์ต์ํํ๋ ๊ฒ ๋ชฉ์ ์ด๋ค.
์์์๋ถํฐ args โ def main โ def train / valid / predict ์์๋ก ๋๋ค.
train/valid ๋ก์ง์ main ์์ ๋ฐ์ง ์๊ณ ํจ์๋ก ๋บ ์ด์ ๋, ๋ฐ์๋๋ฉด ์ฝ๋ ๋ถ๋ฆฌ๊ฐ ์ ๋ผ์ ํ๋์ ์ ๋ณด์ด๊ธฐ ๋๋ฌธ์ด๋ค.
if "__main__" in __name__: ๋ธ๋ก์ logging, setproctitle, set_seed, args ํ์ฑ์ฒ๋ผ ํ ์ค๋ก ๋๋๋ ์ก์ค์ ์ ์ฉ ์ฅ์๋ค.
args๋ฅผ model/data/train์ผ๋ก ๋๋์ง ์๊ณ ํ๋๋ก ์์ํด ํฉ์ณค๋ค.
๋๋ ๋๋ฉด wandb์ train args๋ง ๊ธฐ๋ก๋ผ์ model, dataset ์ค์ ์ด ์ฌํ์ฑ ์ถ์ ์์ ๋๋ฝ๋๋ค. ์คํ ๊ด๋ฆฌ๋ฅผ wandb๋ก ํ๋ค ๋ณด๋ ์ ๋ถ ํ ๊ณณ์ ๊ธฐ๋ก๋๊ฒ ํฉ์ณค๋ค.
ํ์ต๋ง๋ค ๋ณ๋ ์ ์ฒ๋ฆฌ ์ฝ๋๋ฅผ ์ง๋ ๋์ , ์์ฃผ ์ฐ๋ ๊ธฐ๋ฅ์ args ๋งต์ผ๋ก ์ ๊ทํํ๋ค.
dataset_name_map: subset(name) ๋งคํdataset_prefix: split โ ์ญํ (train/valid/test) ๋งคํdataset_truncate_map: ์ํ ์ ์กฐ์ / ๋ถํdataset_files_map: ๋ก์ปฌ ํ์ผ ๋งคํ
์ฌ๋ฌ dataset์ ์์ด ์ฐ๋ฉด split, ๋ถํ , ์ด์์น ์ฒ๋ฆฌ ๊ฐ์ 1ํ์ฉ ์์ ์ด ๋งค๋ฒ ๋ฐ๋ผ๋ถ๋๋ฐ, ์ด๊ฑธ ๋งต์ผ๋ก ๋นผ์ args๋ง ๋ฐ๊ฟ ๋๋ผ ์ ์๊ฒ ํ๋ค.
๊ถ์ฅ ์ ๋ ฅ ์ปฌ๋ผ:
- SFT:
conversations,prompt,answer,images - Pretrain:
corpus,sentence_ls
๋๋ถ๋ถ์ args๋ HF TrainingArguments๋ฅผ ๊ทธ๋๋ก ๋ฐ๋ฅธ๋ค. SFT๋ ์ถ๊ฐ๋ก TRL SFTConfig, ModelConfig, DPO๋ DPOConfig, ModelConfig๋ฅผ ์์. ๋ด๊ฐ ์ถ๊ฐํ args ๊ธฐ๋ฅ์ ์๋์ ๊ฐ๋ค.
๋ฐ์ดํฐ
dataset_repo_ls: ํ์ต์ ์ธ dataset repo ๋ชฉ๋กdataset_name_map/dataset_prefix/dataset_truncate_map/dataset_files_map: ์ ๋ฐ์ดํฐ ์ ์ฒ๋ฆฌ ์ฐธ๊ณdataset_batch_size: ์ ์ฒ๋ฆฌ ๋ฐฐ์น ํฌ๊ธฐ
๋ชจ๋ธ/ํ ํฌ๋์ด์
config_kwargs: AutoConfig์ ๋๊ธธ ์ถ๊ฐ ์ธ์tokenizer_kwargs(SFT) /processor_kwargs(DPO): ํ ํฌ๋์ด์ , ํ๋ก์ธ์ ์ถ๊ฐ ์ธ์chat_template_path(DPO): chat template ์ง์ lora_kwargs: LoraConfig์ ๋๊ธธ ์ถ๊ฐ ์ธ์ (peft_kwargs๋ก LoRA ์ธ PEFT ํ์ ํ์ฅ๋ ๊ฐ๋ฅ)
ํ์ต/ํ๊ฐ
dataset_type(SFT): sft / pretrain / dpo ์ ํpacking/packing_strategy/eval_packing: packing ์ ์ดeval_harness_tasks: ํ์ต ์ค lm-eval-harness๋ก ํ๊ฐํ ํ์คํฌ ๋ชฉ๋กgpu_mem_check: forward/backward/optimizer ๊ตฌ๊ฐ๋ณ GPU peak memory ๋ก๊น
์ฐธ๊ณ ๋ก wandb๋ก reportํ๋ฉด ์ฝ๋ ์ค๋ ์ท์ ์ํฐํฉํธ๋ก ๊ฐ์ด ์ ์ฅํด ์ฌํ์ฑ์ ๋จ๊ธด๋ค.
์ฝ๋ฐฑ์ src/callbacks.py์ ๋ชจ์ฌ ์๊ณ , SFT, DPO๊ฐ ๊ณต์ฉํ๋ค. ํด๋น args๊ฐ ์ผ์ง ๋๋ง ๋ถ๋๋ค.
EvalHarnessCallBack(eval_harness_tasks): ํ์ต ์ค eval ์์ ์ lm-eval-harness๋ฅผ ๊ฐ์ด ๋๋ฆฐ๋ค. ์์ ์ ํ์ต์ด ๋๋ ๋ค checkpoint๋ฅผ ํ๋์ฉ lm eval ๋๋ ธ๋๋ฐ, ๋๋ฌด ๋ฒ๊ฑฐ๋กญ๊ณ ์๊ฐ๋ ๋ง์ด ๋จน์๋ค. trainer์ eval/predict loop๊ฐ ์๋ ์ด์ ๊ฐ ์๋ํํด ์ฌ๋ ๊ณต์๋ฅผ ์ค์ด๋ ๊ฒ์ด๋ผ์ lm eval๋ ๊ทธ ์์ผ๋ก ํฌํจ์ํด.WandbCodeArtifactCallback(report_to์ wandb): base code ๋๋น ๋ญ๊ฐ ๋ฐ๋์๊ณ ์คํ์ด ์ด๋ป๊ฒ ๋๋์ง๋ฅผ artifact๋ก ๋จ๊ธด๋ค. ์ฌํ์ฑ, ๋ฒ์ ๊ด๋ฆฌ๋ฅผ wandb๋ก ํ๊ธฐ ๋๋ฌธ์ด๋ค.GpuMemoryCallback(gpu_mem_check): ๋จ์ memory ๋๋ฒ๊น ์ฉ. ๊ตฌ๊ฐ๋ณ GPU peak memory๋ฅผ ์ฐ๋๋ค.
SFT ์คํฌ๋ฆฝํธ ํ๋๋ก sft์ pretrain์ ๋ ๋ค ๋๋ฆฐ๋ค. dataset_type๋ง sft / pretrain์ผ๋ก ๋ฐ๊พธ๋ฉด ๋๋ค.
์ ์ฒ๋ฆฌ(assistant-only ๋ผ๋ฒจ๋ง vs ํต๋ฌธ์ฅ ํ ํฐํ)์ collator์ ๋ผ๋ฒจ ์ฒ๋ฆฌ(labels vs input_ids)๊ฐ ์ด ๊ฐ์ ๋ฐ๋ผ ๊ฐ๋ฆด ๋ฟ, packing, ์บ์, ๋ฐ์ดํฐ ๋งต ๊ฐ์ ๋๋จธ์ง ํ์ดํ๋ผ์ธ์ ๊ทธ๋๋ก ๊ณต์ ํ๋ค. ๊ทธ๋์ ๊ฐ์ ๋ฐ์ดํฐ ๊ตฌ์ฑ ์์์ pretrain ํ sft๋ก ์ด์ด๊ฐ๋ ์์ ์คํ์ args๋ง ๋ฐ๊ฟ์ ํ ์ ์๋ค.
์ ์ฒ๋ฆฌ ์บ์๋ ์ต์ด 1ํ ์คํ์์๋ง ์์ฑ๋๋ค. ์ต์ด 1ํ๋ main_process_first ์ปจํ
์คํธ๋ก main process์์๋ง ๋ง๋ค๊ณ , ์ดํ ์คํ์ datasets๊ฐ ๊ทธ ์บ์๋ฅผ ๊ทธ๋๋ก ๋ก๋ํ๋ค.
์บ์ ์์น๋ ์๋ณธ ๋ฐ์ดํฐ์ ๋ถ๋ฆฌํ๊ณ , ์ํ(๋ชจ๋ธ/repo/run/max_length)๋ณ๋ก ๋๋ ์ ์ฅํ๋ค.
datasets๋ arrow ์บ์๋ฅผ ์๋ณธ๊ณผ ๊ฐ์ ํด๋์ ์๋๋ฐ, ์บ์๋ฅผ ์ ๋ฆฌํ๋ค ์๋ณธ๊น์ง ์ง์ฐ๋ ์ฌ๊ณ ๊ฐ ์ฆ์๋ค. ๊ทธ๋์ ๋ถ๋ฆฌํ๋ค.
git clone https://github.com/jp1924/LLM.git
docker compose up -d # ๋์ปค ํ๊ฒฝ ๊ตฌ์ฑ
docker compose stop # ๋์ปค ํ๊ฒฝ ์ข
๋ฃ
Docker/Compose: 27.3.1-build ce12230, v2.29.7 ํจํค์ง ๋ฒ์ ์ด ๋นจ๋ฆฌ ๋ฐ๋์ด์ ์คํ ์ ๋๋ฉด Dockerfile ์์ ์ด ํ์ํ ์ ์๋ค.
ํ์ต ์คํ์ scripts/ ์๋ ์คํฌ๋ฆฝํธ๋ฅผ ์ด๋ค.
bash scripts/run_zero3_sft.sh
bash scripts/run_zero3_dpo.sh
pypi๋ก ๋ฐ์ผ๋ฉด ์ปดํ์ผ์ ์ค๋ ๊ฑธ๋ฆฌ์ง๋ง, flash-attn repo๊ฐ ์ฃผ๋ whl์ ์ปดํ์ผ์ด ๋๋ ์์ด ๋น ๋ฅด๋ค. ๋จผ์ ํ๊ฒฝ ๋ฒ์ ์ ํ์ธํ๋ค.
python -c "import torch; print(torch.__version__, torch.version.cuda, torch.compiled_with_cxx11_abi())"
์ดํ flash-attn repo์์ ํ๊ฒฝ์ ๋ง๋ whl์ ๋ฐ์ ์ค์นํ๋ฉด ๋๋ค.
๋ด๊ฐ ์ฃผ๋ก ์์ ํ๋ ๋ฐฉ์. ๊ทธ๋ฅ ์ฐธ๊ณ ์ฉ์.
- docker ํ๊ฒฝ
/root/workspace์์์ ์์ ํ๋ค. ์ด๊ธฐ ํ์ด์ฌ ์ธํ ์uv sync๋ก ๋ง์ถ๋ค. - ์ฃผ๋ก VSCode + tmux๋ก ์์
.
.vscode์ ์์ฃผ ์ฐ๋ vscode extensions๊ณผ debug์ ํ์ํ ์ค์ ์launch.json์ ์์น. - tmux, bashrc, vi ๊ฐ์ ๊ฐ๋ฐํ๊ฒฝ ๊ตฌ์ฑ ํ์ผ์
envํด๋์ - ํ์ต ๋ก๊ทธ๋
tee๋กlogs/{method}.log์ ๋จ๊ธด๋ค.
์คํ ๊ด๋ฆฌ๋ wandb๋ก ํ๋ค. ๊ทธ๋์ config/{method}.yaml์ ์คํ๋ง๋ค ๋ณต์ฌํ๊ฑฐ๋ ์ด๋ฆ์ ๋ฐ๊ฟ ์ฐจ์ด๋ฅผ ์ค๋ช
ํ์ง ์๊ณ , WANDB_NOTES env๋ก ์คํ๋ณ ํน์ง์ ์ ๋ ๊ฑธ ์ ํธํ๋ค.
๋ก๊ทธ๋ {method}.log ํ๋๋ก ๋๋ค. ๋ณต์ฌ๋ณธ์ ๋ง๋ค๋ฉด ๋์ค์ diff, ๋ฒ์ ๊ด๋ฆฌ๊ฐ ์ด๋ ค์์ง๊ธฐ ๋๋ฌธ์ด๋ค.