jp1924/LLM

LLM

A repo that collects code for training LLMs/LVMs with SFT, DPO, and pretraining, and for preprocessing their data. It hardens the training recipes I use most into code. The focus is on cutting down the one-off code I used to write every time and on using the GPU efficiently, without waste.

Motivation

  • LLM training is too inefficient. This repo implements methods that raise throughput while using GPU resources to the fullest.
  • It organizes the training styles I use often into fixed code.
  • It minimizes one-off work.

Directory layout

Under {method}/, always keep the following:

  • main.py (required): args + def main + if "__main__" in __name__
  • preprocessor.py: data preprocessing
  • Anything else (callbacks, etc.) is split out only into files like callbacks.py

I used to split things into args.py and collator.py, but that only multiplied files and VSCode tabs, and I could no longer take everything in at a glance. It goes against Python's module philosophy, but I chose to consolidate into main.py. The goal is to keep the number of files to a minimum.

main.py conventions

From top to bottom, the order is args → def main → def train / valid / predict. The train/valid logic is pulled out into functions instead of being embedded in main. Embedding it leaves the code unseparated and hard to read at a glance.

The if "__main__" in __name__: block is reserved for one-line setup chores such as logging, setproctitle, set_seed, and args parsing.
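A minimal, self-contained sketch of that layout follows. It uses only the standard library so it runs anywhere. The TrainArgs fields and the toy train/valid bodies are illustrative assumptions, not the repo's code. setproctitle, a third-party package, is left out of the setup block.

```python
import argparse
import logging
import random
from dataclasses import dataclass


@dataclass
class TrainArgs:
    # 1) args (fields are illustrative)
    seed: int = 42
    num_steps: int = 3
    do_eval: bool = True


def main(args: TrainArgs) -> dict:
    # 2) main only wires the phases together
    results = {"train_loss": train(args)}
    if args.do_eval:
        results["valid_loss"] = valid(args)
    return results


def train(args: TrainArgs) -> float:
    # 3) each phase lives in its own function
    loss = 1.0
    for _ in range(args.num_steps):
        loss *= 0.5  # toy stand-in for an optimizer step
    return loss


def valid(args: TrainArgs) -> float:
    return 0.25  # toy stand-in for an eval loop


if "__main__" in __name__:
    # one-line setup chores only
    logging.basicConfig(level=logging.INFO)
    parser = argparse.ArgumentParser()
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument("--num_steps", type=int, default=3)
    cli, _ = parser.parse_known_args()
    random.seed(cli.seed)
    logging.info(main(TrainArgs(seed=cli.seed, num_steps=cli.num_steps)))
```

Because main is reduced to a few calls, the order of the phases can be read at a glance.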

๋‹จ์ผ train_args

args๋ฅผ model/data/train์œผ๋กœ ๋‚˜๋ˆ„์ง€ ์•Š๊ณ  ํ•˜๋‚˜๋กœ ์ƒ์†ํ•ด ํ•ฉ์ณค๋‹ค.

๋‚˜๋ˆ ๋‘๋ฉด wandb์—” train args๋งŒ ๊ธฐ๋ก๋ผ์„œ model, dataset ์„ค์ •์ด ์žฌํ˜„์„ฑ ์ถ”์ ์—์„œ ๋ˆ„๋ฝ๋œ๋‹ค. ์‹คํ—˜ ๊ด€๋ฆฌ๋ฅผ wandb๋กœ ํ•˜๋‹ค ๋ณด๋‹ˆ ์ „๋ถ€ ํ•œ ๊ณณ์— ๊ธฐ๋ก๋˜๊ฒŒ ํ•ฉ์ณค๋‹ค.

๋ฐ์ดํ„ฐ ์ „์ฒ˜๋ฆฌ

ํ•™์Šต๋งˆ๋‹ค ๋ณ„๋„ ์ „์ฒ˜๋ฆฌ ์ฝ”๋“œ๋ฅผ ์งœ๋Š” ๋Œ€์‹ , ์ž์ฃผ ์“ฐ๋Š” ๊ธฐ๋Šฅ์„ args ๋งต์œผ๋กœ ์ •๊ทœํ™”ํ–ˆ๋‹ค.

  • dataset_name_map: subset(name) ๋งคํ•‘
  • dataset_prefix: split โ†’ ์—ญํ• (train/valid/test) ๋งคํ•‘
  • dataset_truncate_map: ์ƒ˜ํ”Œ ์ˆ˜ ์กฐ์ ˆ / ๋ถ„ํ• 
  • dataset_files_map: ๋กœ์ปฌ ํŒŒ์ผ ๋งคํ•‘

์—ฌ๋Ÿฌ dataset์„ ์„ž์–ด ์“ฐ๋ฉด split, ๋ถ„ํ• , ์ด์ƒ์น˜ ์ฒ˜๋ฆฌ ๊ฐ™์€ 1ํšŒ์šฉ ์ž‘์—…์ด ๋งค๋ฒˆ ๋”ฐ๋ผ๋ถ™๋Š”๋ฐ, ์ด๊ฑธ ๋งต์œผ๋กœ ๋นผ์„œ args๋งŒ ๋ฐ”๊ฟ” ๋๋‚ผ ์ˆ˜ ์žˆ๊ฒŒ ํ–ˆ๋‹ค.
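The README does not spell out the value formats of these maps, so the following is a hypothetical sketch. It assumes dataset_prefix maps each role to the split names that feed it, and dataset_truncate_map caps rows per repo and split. Plain lists stand in for datasets. With the datasets library, the same steps would use Dataset.select and concatenate_datasets.

```python
def assemble_splits(raw, dataset_prefix, dataset_truncate_map=None):
    """Group rows from several repos into roles (hypothetical helper).

    raw:                  {repo: {split_name: [rows]}}
    dataset_prefix:       {role: [split names feeding that role]}   (assumed shape)
    dataset_truncate_map: {repo: {split_name: max_rows}}           (assumed shape)
    """
    dataset_truncate_map = dataset_truncate_map or {}
    by_role = {role: [] for role in dataset_prefix}
    for repo, splits in raw.items():
        limits = dataset_truncate_map.get(repo, {})
        for split_name, rows in splits.items():
            limit = limits.get(split_name)
            kept = rows if limit is None else rows[:limit]  # truncate if capped
            for role, split_names in dataset_prefix.items():
                if split_name in split_names:
                    by_role[role].extend(kept)
    return by_role


raw = {
    "repo_a": {"train": [1, 2, 3, 4], "validation": [5, 6]},
    "repo_b": {"train": [7, 8], "dev": [9]},
}
roles = assemble_splits(
    raw,
    dataset_prefix={"train": ["train"], "valid": ["validation", "dev"]},
    dataset_truncate_map={"repo_a": {"train": 2}},
)
```

Switching a repo's dev split into the test role, or changing a cap, is then a change to the maps rather than new code.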

Recommended input columns:

  • SFT: conversations, prompt, answer, images
  • Pretrain: corpus, sentence_ls
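One plausible way to accept several column layouts is to normalize them early. The sketch below assumes two shapes: conversations is a list of {"role", "content"} turns, and sentence_ls is a list of strings joined with newlines. Both shapes and the helper names are assumptions. Handling of images is omitted.

```python
def to_conversations(row: dict) -> list:
    """Hypothetical: normalize an SFT row into a list of {role, content} turns."""
    if "conversations" in row:
        return row["conversations"]
    if "prompt" in row and "answer" in row:
        return [
            {"role": "user", "content": row["prompt"]},
            {"role": "assistant", "content": row["answer"]},
        ]
    raise KeyError("expected 'conversations' or 'prompt'/'answer'")


def to_corpus(row: dict, sep: str = "\n") -> str:
    """Hypothetical: normalize a pretrain row into one text string."""
    if "corpus" in row:
        return row["corpus"]
    if "sentence_ls" in row:
        return sep.join(row["sentence_ls"])
    raise KeyError("expected 'corpus' or 'sentence_ls'")
```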

Args

Most args follow HF TrainingArguments as-is. SFT additionally inherits TRL SFTConfig and ModelConfig; DPO inherits DPOConfig and ModelConfig. The args I added are listed below.

Data

  • dataset_repo_ls: list of dataset repos to train on
  • dataset_name_map / dataset_prefix / dataset_truncate_map / dataset_files_map: see Data preprocessing above
  • dataset_batch_size: batch size used during preprocessing

๋ชจ๋ธ/ํ† ํฌ๋‚˜์ด์ €

  • config_kwargs: AutoConfig์— ๋„˜๊ธธ ์ถ”๊ฐ€ ์ธ์ž
  • tokenizer_kwargs(SFT) / processor_kwargs(DPO): ํ† ํฌ๋‚˜์ด์ €, ํ”„๋กœ์„ธ์„œ ์ถ”๊ฐ€ ์ธ์ž
  • chat_template_path(DPO): chat template ์ง€์ •
  • lora_kwargs: LoraConfig์— ๋„˜๊ธธ ์ถ”๊ฐ€ ์ธ์ž (peft_kwargs๋กœ LoRA ์™ธ PEFT ํƒ€์ž… ํ™•์žฅ๋„ ๊ฐ€๋Šฅ)
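config_kwargs and lora_kwargs are dict-valued, and a common way to pass such args on the command line is as JSON strings. This sketch uses argparse as a stand-in for whatever parser the repo uses. The default LoRA values are hypothetical.

```python
import argparse
import json

parser = argparse.ArgumentParser()
# argparse also applies `type` to string defaults, so each parse gets a fresh {}
parser.add_argument("--config_kwargs", type=json.loads, default="{}")
parser.add_argument("--lora_kwargs", type=json.loads, default="{}")

args = parser.parse_args([
    "--config_kwargs", '{"use_cache": false}',
    "--lora_kwargs", '{"r": 16, "lora_alpha": 32}',
])

# Hypothetical defaults; user-supplied keys win on conflict.
lora_defaults = {"r": 8, "lora_alpha": 16, "lora_dropout": 0.05}
lora_settings = {**lora_defaults, **args.lora_kwargs}

# With transformers/peft installed, these would be splatted in roughly as:
#   AutoConfig.from_pretrained(model_name, **args.config_kwargs)
#   LoraConfig(**lora_settings)
```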

Training/evaluation

  • dataset_type (SFT): choose sft / pretrain / dpo
  • packing / packing_strategy / eval_packing: control packing
  • eval_harness_tasks: list of tasks to evaluate with lm-eval-harness during training
  • gpu_mem_check: logs GPU peak memory separately for the forward, backward, and optimizer phases
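The following first-fit-decreasing sketch illustrates what packing buys. It fills fixed-length rows with whole sequences and restarts position_ids at each sequence boundary, so fewer tokens are wasted on padding. This illustrates the idea only; it is not TRL's implementation, and the values TRL accepts for packing_strategy are not modeled here. Restarting position_ids per sequence is what padding-free attention paths (such as flash-attn varlen kernels driven by position_ids) can use to keep packed sequences from attending to each other.

```python
def pack_ffd(seqs, max_length):
    """First-fit decreasing: place each sequence, longest first, into the
    first row that still has room; open a new row when none fits."""
    rows, space = [], []  # sequences per row, remaining capacity per row
    for seq in sorted(seqs, key=len, reverse=True):
        seq = seq[:max_length]  # truncate anything longer than a row
        for i, free in enumerate(space):
            if len(seq) <= free:
                rows[i].append(seq)
                space[i] -= len(seq)
                break
        else:
            rows.append([seq])
            space.append(max_length - len(seq))

    packed = []
    for row in rows:
        packed.append({
            "input_ids": [tok for seq in row for tok in seq],
            # positions restart at 0 for every packed sequence
            "position_ids": [p for seq in row for p in range(len(seq))],
        })
    return packed


packed = pack_ffd([[1, 1, 1], [2, 2], [3, 3, 3, 3], [4]], max_length=5)
```

Four sequences totaling 10 tokens fit into two rows of 5 with no padding. Without packing, padding each sequence to 5 would need four rows.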

Note that when reporting to wandb, a snapshot of the code is also saved as an artifact to preserve reproducibility.
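Here is a stdlib-only sketch of a code snapshot. It zips the project's source files and derives a short content hash, so that two runs with identical code can be recognized. The file patterns and the function name are assumptions, not the repo's.

```python
import hashlib
import zipfile
from pathlib import Path


def snapshot_code(root, out_zip, patterns=("*.py", "*.sh", "*.yaml")):
    """Zip matching files under root; return a short hash of paths + contents."""
    root = Path(root)
    files = sorted({p for pat in patterns for p in root.rglob(pat) if p.is_file()})
    digest = hashlib.sha256()
    with zipfile.ZipFile(out_zip, "w", zipfile.ZIP_DEFLATED) as zf:
        for path in files:
            rel = path.relative_to(root).as_posix()
            data = path.read_bytes()
            zf.writestr(rel, data)
            digest.update(rel.encode())
            digest.update(data)
    return digest.hexdigest()[:12]
```

With wandb, the resulting zip could then be attached to the run through its artifact API: create `wandb.Artifact(name, type="code")`, call `artifact.add_file(path)`, then `run.log_artifact(artifact)`.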

callbacks

์ฝœ๋ฐฑ์€ src/callbacks.py์— ๋ชจ์—ฌ ์žˆ๊ณ , SFT, DPO๊ฐ€ ๊ณต์šฉํ•œ๋‹ค. ํ•ด๋‹น args๊ฐ€ ์ผœ์งˆ ๋•Œ๋งŒ ๋ถ™๋Š”๋‹ค.

  • EvalHarnessCallBack (eval_harness_tasks): ํ•™์Šต ์ค‘ eval ์‹œ์ ์— lm-eval-harness๋ฅผ ๊ฐ™์ด ๋Œ๋ฆฐ๋‹ค. ์˜ˆ์ „์—” ํ•™์Šต์ด ๋๋‚œ ๋’ค checkpoint๋ฅผ ํ•˜๋‚˜์”ฉ lm eval ๋Œ๋ ธ๋Š”๋ฐ, ๋„ˆ๋ฌด ๋ฒˆ๊ฑฐ๋กญ๊ณ  ์‹œ๊ฐ„๋„ ๋งŽ์ด ๋จน์—ˆ๋‹ค. trainer์— eval/predict loop๊ฐ€ ์žˆ๋Š” ์ด์œ ๊ฐ€ ์ž๋™ํ™”ํ•ด ์‚ฌ๋žŒ ๊ณต์ˆ˜๋ฅผ ์ค„์ด๋Š” ๊ฒƒ์ด๋ผ์„œ lm eval๋„ ๊ทธ ์•ˆ์œผ๋กœ ํฌํ•จ์‹œํ‚ด.
  • WandbCodeArtifactCallback (report_to์— wandb): base code ๋Œ€๋น„ ๋ญ๊ฐ€ ๋ฐ”๋€Œ์—ˆ๊ณ  ์‹คํ—˜์ด ์–ด๋–ป๊ฒŒ ๋˜๋Š”์ง€๋ฅผ artifact๋กœ ๋‚จ๊ธด๋‹ค. ์žฌํ˜„์„ฑ, ๋ฒ„์ „ ๊ด€๋ฆฌ๋ฅผ wandb๋กœ ํ•˜๊ธฐ ๋•Œ๋ฌธ์ด๋‹ค.
  • GpuMemoryCallback (gpu_mem_check): ๋‹จ์ˆœ memory ๋””๋ฒ„๊น…์šฉ. ๊ตฌ๊ฐ„๋ณ„ GPU peak memory๋ฅผ ์ฐ๋Š”๋‹ค.
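Here is a sketch of the "attach only when the arg is enabled" rule. The classes below are empty stand-ins named after the real callbacks. The CallbackArgs fields are hypothetical but mirror the args listed above.

```python
from dataclasses import dataclass, field


# Empty stand-ins for the callback classes in src/callbacks.py
class EvalHarnessCallBack:
    def __init__(self, tasks):
        self.tasks = tasks


class WandbCodeArtifactCallback:
    pass


class GpuMemoryCallback:
    pass


@dataclass
class CallbackArgs:
    eval_harness_tasks: list = field(default_factory=list)
    report_to: list = field(default_factory=list)
    gpu_mem_check: bool = False


def build_callbacks(args: CallbackArgs) -> list:
    callbacks = []
    if args.eval_harness_tasks:
        callbacks.append(EvalHarnessCallBack(args.eval_harness_tasks))
    if "wandb" in args.report_to:
        callbacks.append(WandbCodeArtifactCallback())
    if args.gpu_mem_check:
        callbacks.append(GpuMemoryCallback())
    return callbacks


# The list would then go to the trainer, e.g. Trainer(..., callbacks=build_callbacks(args))
```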

Mixing SFT and Pretrain

A single SFT script runs both sft and pretrain. Just switch dataset_type between sft and pretrain.

Only two things branch on this value: the preprocessing (assistant-only labeling vs whole-text tokenization) and the collator's label handling (labels vs input_ids). The rest of the pipeline, such as packing, caching, and the data maps, is shared unchanged. So an experiment like pretraining and then continuing with SFT on the same data setup can be run just by changing args.
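The branch described above comes down to which tokens count as training targets. This is a minimal sketch that assumes preprocessing already produced an assistant_mask marking assistant-turn tokens; how the repo derives it is not shown here. -100 is the default ignore_index of PyTorch's cross-entropy loss. HF causal-LM models shift labels internally, so labels stay aligned with input_ids.

```python
IGNORE_INDEX = -100  # default ignore_index of torch.nn.CrossEntropyLoss


def make_labels(input_ids, assistant_mask, dataset_type):
    """Build labels for one tokenized example (hypothetical helper).

    assistant_mask[i] == 1 marks tokens inside assistant turns.
    """
    if dataset_type == "pretrain":
        # whole-text tokenization: every token is a target
        return list(input_ids)
    if dataset_type == "sft":
        # assistant-only labeling: system/user tokens are ignored by the loss
        return [tok if keep else IGNORE_INDEX for tok, keep in zip(input_ids, assistant_mask)]
    raise ValueError(f"unsupported dataset_type: {dataset_type!r}")


ids = [101, 7, 8, 102, 9, 10]
mask = [0, 0, 0, 0, 1, 1]
```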

์บ์‹œ

์ „์ฒ˜๋ฆฌ ์บ์‹œ๋Š” ์ตœ์ดˆ 1ํšŒ ์‹คํ–‰์—์„œ๋งŒ ์ƒ์„ฑ๋œ๋‹ค. ์ตœ์ดˆ 1ํšŒ๋Š” main_process_first ์ปจํ…์ŠคํŠธ๋กœ main process์—์„œ๋งŒ ๋งŒ๋“ค๊ณ , ์ดํ›„ ์‹คํ–‰์€ datasets๊ฐ€ ๊ทธ ์บ์‹œ๋ฅผ ๊ทธ๋Œ€๋กœ ๋กœ๋“œํ•œ๋‹ค.

์บ์‹œ ์œ„์น˜๋Š” ์›๋ณธ ๋ฐ์ดํ„ฐ์™€ ๋ถ„๋ฆฌํ•˜๊ณ , ์ƒํƒœ(๋ชจ๋ธ/repo/run/max_length)๋ณ„๋กœ ๋‚˜๋ˆ  ์ €์žฅํ•œ๋‹ค.

datasets๋Š” arrow ์บ์‹œ๋ฅผ ์›๋ณธ๊ณผ ๊ฐ™์€ ํด๋”์— ์Œ“๋Š”๋ฐ, ์บ์‹œ๋ฅผ ์ •๋ฆฌํ•˜๋‹ค ์›๋ณธ๊นŒ์ง€ ์ง€์šฐ๋Š” ์‚ฌ๊ณ ๊ฐ€ ์žฆ์•˜๋‹ค. ๊ทธ๋ž˜์„œ ๋ถ„๋ฆฌํ–ˆ๋‹ค.
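Here is a sketch of a cache layout keyed by that state. The directory names, the sanitization of "org/name" ids, and the function names are assumptions. The point is that the cache root is separate from the raw data, so wiping it cannot touch the originals.

```python
from pathlib import Path


def _safe(name: str) -> str:
    # "org/model" -> "org__model" so hub ids become single path components
    return name.replace("/", "__")


def preprocess_cache_dir(cache_root, model_name, repo, run_name, max_length):
    """Hypothetical layout: <cache_root>/<model>/<repo>/<run>/max_len_<n>."""
    return (
        Path(cache_root)
        / _safe(model_name)
        / _safe(repo)
        / _safe(run_name)
        / f"max_len_{max_length}"
    )
```

In a training script, a file inside this directory could be passed to `Dataset.map(..., cache_file_name=...)`. Wrapping the map call in `TrainingArguments.main_process_first(...)` means only the main process builds the cache while the other ranks wait and then reuse it.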

running

git clone https://github.com/jp1924/LLM.git
docker compose up -d # set up the Docker environment
docker compose stop  # stop the Docker environment

Docker/Compose: 27.3.1-build ce12230, v2.29.7. Package versions change quickly, so if it does not run, you may need to modify the Dockerfile.

ํ•™์Šต ์‹คํ–‰์€ scripts/ ์•„๋ž˜ ์Šคํฌ๋ฆฝํŠธ๋ฅผ ์“ด๋‹ค.

bash scripts/run_zero3_sft.sh
bash scripts/run_zero3_dpo.sh

Installing flash-attn

Installing from PyPI means a long compile. The whl files provided by the flash-attn repo are already compiled, so they install quickly. First, check your environment's versions.

python -c "import torch; print(torch.__version__, torch.version.cuda, torch.compiled_with_cxx11_abi())"

Then download the whl that matches your environment from the flash-attn repo and install it.
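The version check above points to the matching wheel. The hypothetical helper below assembles the filename pattern used for flash-attn's prebuilt release wheels. It assumes Linux x86_64. The CUDA tag format has changed across releases: recent ones use only the CUDA major version (cu12), older ones major+minor (cu122). Compare the result against the actual release asset list. The Python version also matters, even though the command above does not print it.

```python
import sys


def flash_attn_wheel_name(fa_version, torch_version, cuda_version, cxx11_abi,
                          py_tag=None, cuda_major_only=True):
    """Hypothetical helper: guess a flash-attn release wheel filename.

    torch_version / cuda_version / cxx11_abi are what the version check prints,
    e.g. "2.6.0+cu124", "12.4", False.
    """
    if py_tag is None:
        py_tag = f"cp{sys.version_info.major}{sys.version_info.minor}"
    major, minor = cuda_version.split(".")[:2]
    cu = f"cu{major}" if cuda_major_only else f"cu{major}{minor}"
    torch_mm = ".".join(torch_version.split("+")[0].split(".")[:2])  # "2.6.0+cu124" -> "2.6"
    abi = "TRUE" if cxx11_abi else "FALSE"
    return (f"flash_attn-{fa_version}+{cu}torch{torch_mm}cxx11abi{abi}"
            f"-{py_tag}-{py_tag}-linux_x86_64.whl")
```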

Working environment and workflow

How I usually work. Just for reference.

  • I work inside /root/workspace in the Docker environment. The initial Python setup is done with uv sync.
  • I mostly work with VSCode + tmux. Frequently used VSCode extensions are kept in .vscode, and the settings needed for debugging are in launch.json.
  • Dev environment config files such as tmux, bashrc, and vi settings live in the env folder.
  • Training logs are saved to logs/{method}.log with tee.

Experiments are managed with wandb. So instead of copying or renaming config/{method}.yaml for each experiment to describe the differences, I prefer writing each experiment's distinguishing features in the WANDB_NOTES env var. Logs likewise stay in a single {method}.log. Making copies makes diffs and version control harder later on.

About

🤗 Code for training LLMs and LMMs
