Encoder vectorization + CLI flags for training on larger corpora#2
Open
modulovalue wants to merge 2 commits into
Open
Encoder vectorization + CLI flags for training on larger corpora#2modulovalue wants to merge 2 commits into
modulovalue wants to merge 2 commits into
Conversation
The original Encoder had two pure-Python passes: it built the vocab dict by iterating every character and dispatching a dict membership test, and its encode() ran a per-character dict lookup. Both are O(N) in Python. On Shakespeare (~1M chars) that is fine. On larger corpora (e.g. a 500MB wiki dump, ~520M chars) those Python loops dominate startup. This commit: - Builds the vocabulary as `sorted(set(text))` (one pass in C). - Adds `encode_array(text)` which converts text to UTF-32-LE bytes via the C codec, views the buffer as `uint32` codepoints, and gathers through a precomputed lookup table indexed by ord(c). Output is a `np.ndarray[int32]` ready to be moved to a torch tensor. - Caches the inverse dict so decode() does not rebuild it on every call. - Bulk encoding is chunked (default 4M chars/block) so peak transient memory stays bounded for very large corpora. Existing public API is preserved: `encode`, `decode`, and `vocab` behave the same. `encode_array` is additive. Measured on a 500 MB wiki dump (522M characters): vocab build: ~2.6 s (was estimated at minutes) encode_array: ~0.8 s (was estimated at minutes)
Adds a small set of CLI knobs needed to point training at a different corpus and to recover from interruptions, plus a few correctness/perf tweaks that come along for the ride: - `Transformer(data_path=...)` constructor argument; previously the path was hardcoded to "data/input.txt". Threaded through train.py, sample.py, and export_onnx.py via a `--data` flag (default unchanged). - `--batch-size` and `--seq-len` flags so hyperparameters can be tuned without editing source. - `--resume <checkpoint.pt>` flag that loads a saved state_dict before training. Useful for picking up a long run after a crash, machine reboot, or any other interruption. Only the model weights are restored; the optimizer state and step counter are not. - Use the new `encoder.encode_array()` and store the corpus as `int32` on device. The vocabulary easily fits in 32 bits (this PR's wiki sample has ~720 chars, the full HF wiki dump has ~7000), so int64 was wasting 50% of corpus memory. On a 500MB wiki corpus this saves ~2 GB of device memory. - Read the corpus with `f.read()` instead of `"\n".join(f.readlines())`. The old form silently doubled every newline. No vocab change, the encoder was building from the same join'd text. Sanity-checked: training on tiny Shakespeare with default flags gives the same it/s and matching loss curve as before.
681053b to
84b8639
Compare
Owner
|
Looks good. One thing: If you want to resume from a previous checkpoint, you also want to make sure you're resuming the optimizer state imo. |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Two commits that make it practical to point this codebase at corpora bigger and more diverse than Shakespeare. Tested locally on a 500 MB Wikipedia dump (522M chars, 7003-char vocab); steady at 27 it/s on M-series MPS using the same model.
Changes
1. Vectorize Encoder vocab build and bulk encode with numpy
The original
Encoderran two pure-Python passes: a per-character dict-membership check to build the vocab, and a per-character lookup to encode. Fine on Shakespeare, but on a 500 MB wiki dump those Python loops dominate startup (estimated minutes each). New code:set(text)for vocab (one C pass)encode_array(text): encodes text to UTF-32-LE, views asuint32codepoints, gathers through a precomputed lookup table -- returnsnp.ndarray[int32]decode()doesn't rebuild it on every callencode,decode,vocabare unchanged.encode_arrayis additive.Measured on a 500 MB wiki dump (522M characters):
2. Make dataset path configurable and add training CLI flags
Transformer(data_path=...)constructor argument; threaded through train.py, sample.py, export_onnx.py via a--dataflag.--batch-sizeand--seq-lenflags so hyperparameters can be tuned without editing source.--resume <checkpoint.pt>flag that loads a saved state_dict before training -- useful for picking up long runs after a crash. Only model weights are restored; optimizer state and step counter are not.encoder.encode_array()and store the corpus asint32on device. Vocab fits easily in 32 bits, so int64 was wasting 50% of corpus memory. On a 500 MB corpus this saves ~2 GB of device memory.f.read()instead of\"\\n\".join(f.readlines())-- the old form silently doubled every newline (no vocab change, since the encoder was building from the same join'd text).Defaults are unchanged, so
uv run train --device mpsbehaves exactly the same as before. Sanity-checked on Shakespeare: same it/s, matching loss curve.Test plan
uv run train --device mps(defaults to Shakespeare) -- identical it/s and loss curveuv run train --device mps --data <other.txt>-- runs cleanly on multilingual corpora--resume checkpoints/checkpoint.ptloads a previous state_dict and continues training