Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
81 changes: 79 additions & 2 deletions docs/source/advanced/packing.md
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,8 @@ and {py:meth}`pack_selected_samples <megatron.energon.TaskEncoder.pack_selected_
Furthermore, you need to initialize the loader with the `packing_buffer_size` argument set to a non-zero number.

The `select_samples_to_pack` method will receive a list of samples (size according to the selected `packing_buffer_size`),
and should partition those samples into groups that shall be packed together. Hence the function returns
a list of lists of samples.
and should partition those samples into groups that shall be packed together. Hence the function typically returns
a list of lists of samples. Alternatively it may return {py:class}`PackedSamplesOutput <megatron.energon.PackedSamplesOutput>` with a ``pushback`` sequence: those samples are appended back to the reading buffer before the next fill from the dataset.

For each group, the second method `pack_selected_samples` will be called. You need to implement how a group of
samples will be mapped to a single sample. In terms of LLMs for example, this method might concatenate the input tokens.
Expand All @@ -42,6 +42,83 @@ samples for packing. If augmentations happen, it should be marked with
You have to make sure the methods are actually stateless, meaning that they will produce the same output when invoked
with the same input and random states.

## Carrying over partial samples

Sometimes the next sample only partially fits into the remaining packed context. In that case,
{py:meth}`select_samples_to_pack <megatron.energon.TaskEncoder.select_samples_to_pack>` can return a
{py:class}`PartialSample <megatron.energon.PartialSample>` for the part that fits and push back another
`PartialSample` for the remainder.

`PartialSample` stores the original sample and a task-defined slice payload:

```python
@edataclass
class PartialSample(Generic[T_sample, T_slice]):
sample: T_sample
slice: T_slice
```

The `slice` object is stored as-is in loader state and restore keys, so it must be serializable by
the same mechanism you use for checkpointing loader state. A `tuple[int, int]` using normal Python
`(start, stop)` slicing semantics is a typical choice.

The slicing semantics are user-defined. Energon preserves and restores the `slice` payload, but your
task encoder applies it. If you override
{py:meth}`postencode_sample <megatron.energon.TaskEncoder.postencode_sample>` and produce partials,
`postencode_sample` must accept both full samples and `PartialSample` inputs:

```python
def select_samples_to_pack(
self,
samples: list[TokenizedSample | PartialSample[TokenizedSample, tuple[int, int]]],
) -> PackedSamplesOutput[TokenizedSample | PartialSample[TokenizedSample, tuple[int, int]]]:
sample = samples[0]
if isinstance(sample, PartialSample):
base_sample = sample.sample
token_start, token_stop = sample.slice
else:
base_sample = sample
token_start = 0
token_stop = len(sample.tokens)

return PackedSamplesOutput(
packs=[
[
PartialSample(
sample=base_sample,
slice=(token_start, token_start + self.remaining_context),
)
]
],
pushback=(
PartialSample(
sample=base_sample,
slice=(token_start + self.remaining_context, token_stop),
),
),
)


@stateless
def postencode_sample(
self,
sample: TokenizedSample | PartialSample[TokenizedSample, tuple[int, int]],
) -> TokenizedSample:
if isinstance(sample, PartialSample):
token_start, token_stop = sample.slice
sample = sample.sample
return TokenizedSample.derive_from(
sample,
tokens=sample.tokens[token_start:token_stop],
)
return sample
```

If you do not use `postencode_sample`, then
{py:meth}`pack_selected_samples <megatron.energon.TaskEncoder.pack_selected_samples>` receives the
`PartialSample` values directly and must apply the slice there. In that mode, type the final packer
to accept the same union of full and partial samples returned by `select_samples_to_pack`.

Example packing for a large language model extending the example from the [](../basic/task_encoder) section:

```python
Expand Down
4 changes: 2 additions & 2 deletions docs/source/basic/task_encoder.md
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,8 @@ and override one or more of the following methods. The data flow of {py:func}`ge
- {py:meth}`def preencode_sample(self, sample: T_sample) -> T_sample <megatron.energon.TaskEncoder.preencode_sample>`
- Can be used as a replacement to {py:func}`encode_sample <megatron.energon.DefaultTaskEncoder.encode_sample>`
- Only used together with {py:meth}`postencode_sample <megatron.energon.TaskEncoder.postencode_sample>` (below). Use it if you need to instantiate lazy data before packing the selected samples. But it will also work when not using packing.
- {py:meth}`def select_samples_to_pack(self, samples: List[T_encoded_sample]) -> List[List[T_encoded_sample]] <megatron.energon.TaskEncoder.select_samples_to_pack>`
- Optional. Allows for efficient sample packing. See [](../advanced/packing).
- {py:meth}`def select_samples_to_pack(self, samples: List[T_encoded_sample]) -> Union[List[List[T_encoded_sample]], PackedSamplesOutput] <megatron.energon.TaskEncoder.select_samples_to_pack>`
- Optional. Allows for efficient sample packing. See [](../advanced/packing). May return {py:class}`PackedSamplesOutput <megatron.energon.PackedSamplesOutput>` to re-queue a pushback sequence onto the reading buffer.
- {py:meth}`def postencode_sample(self, sample: T_sample) -> T_encoded_sample <megatron.energon.TaskEncoder.postencode_sample>`
- Only used together with {py:meth}`preencode_sample <megatron.energon.TaskEncoder.preencode_sample>`. Use it if you need to instantiate lazy data before packing the selected samples. But it will also work when not using packing.
- {py:meth}`def pack_selected_samples(self, samples: List[T_encoded_sample]) -> T_batch_sample] <megatron.energon.TaskEncoder.pack_selected_samples>`
Expand Down
4 changes: 4 additions & 0 deletions src/megatron/energon/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,9 @@
LogSampleDataset,
MapDataset,
MixBatchDataset,
PackedSamplesOutput,
PackingDataset,
PartialSample,
RepeatDataset,
ShuffleBufferDataset,
concat_pad,
Expand Down Expand Up @@ -182,10 +184,12 @@
"NoCachePool",
"OCRSample",
"OCRWebdataset",
"PartialSample",
"PackingDataset",
"PrimaryFileStore",
"prepare_metadataset",
"RepeatDataset",
"PackedSamplesOutput",
"reraise_exception",
"Sample",
"SampleDecoder",
Expand Down
2 changes: 2 additions & 0 deletions src/megatron/energon/metadataset/dataset_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,7 @@ def get_datasets(
subflavors: Optional[Dict[str, Any]] = None,
shuffle_over_epochs_multiplier: Optional[int] = 1,
subset: Optional[DatasetSubset] = None,
group: Optional[str] = None,
**kwargs,
) -> LoadedDatasetList:
return LoadedDatasetList(
Expand All @@ -106,6 +107,7 @@ def get_datasets(
**kwargs,
),
weight=None,
group=group,
)
],
)
2 changes: 2 additions & 0 deletions src/megatron/energon/metadataset/join_dataset_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -541,6 +541,7 @@ def get_datasets(
subflavors: Optional[Dict[str, Any]] = None,
shuffle_over_epochs_multiplier: Optional[int] = 1,
subset: Optional[DatasetSubset] = None,
group: Optional[str] = None,
**kwargs,
) -> LoadedDatasetList:
return LoadedDatasetList(
Expand All @@ -557,6 +558,7 @@ def get_datasets(
**kwargs,
),
weight=None,
group=group,
)
],
)
28 changes: 28 additions & 0 deletions src/megatron/energon/metadataset/loader_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,17 @@ class DatasetBlendMode(Enum):

@edataclass
class LoadedDataset:
#: The dataset factory.
dataset: BaseCoreDatasetFactory
#: Sampling weight when using dataset-weight blending.
weight: Union[float, int, None] = None
#: Epochized repetition count when using repetition-based blending.
repetitions: Union[float, int, None] = None
#: Dataset group key from Metadataset V2 (YAML ``group`` field). Must match keys used in
#: ``packing_buffer_size`` / ``shuffle_buffer_size`` when those are dicts. ``None`` is the
#: default group.
group: Optional[str] = None
#: Auxiliary datasets for crude cooking.
aux: Optional[Dict[str, FileStore]] = None


Expand All @@ -48,12 +56,18 @@ class TraversedDatasetReference:
split_part: Effective split part to use when loading the leaf dataset.
aux: Resolved auxiliary dataset or filesystem references keyed by auxiliary name.
subflavors: Effective subflavors implied by the traversed metadataset hierarchy.
group: Merged dataset group name from Metadataset V2 ``group`` fields along the path to this
leaf (nested segments join with ``+``). Keys dict entries for per-group
``packing_buffer_size`` and ``shuffle_buffer_size``.
shuffle_over_epochs_multiplier: Effective shuffle over epochs multiplier from metadataset references.
"""

path: EPath
split_part: str
aux: dict[str, EPath]
subflavors: dict[str, Any]
group: Optional[str] = None
shuffle_over_epochs_multiplier: Optional[int] = 1


class DatasetLoaderInterface(ABC):
Expand All @@ -69,6 +83,8 @@ def traverse(
mds_path: Optional[EPath] = None,
*,
split_part: Union[Literal["train", "val", "test"], str],
_group: Optional[str] = None,
_shuffle_over_epochs_multiplier: Optional[int] = 1,
_subflavors: Optional[Dict[str, Any]] = None,
) -> List[TraversedDatasetReference]:
"""Traverse a metadataset subtree and collect flattened leaf dataset references.
Expand All @@ -83,6 +99,12 @@ def traverse(
use None only for top-level metadatasets.
split_part: Split to traverse, such as `\"train\"`, `\"val\"`, or `\"test\"`. Nested
references may override this with their own configured split.
_group: Inherited merged group name from Metadataset V2 ``group`` fields (nested segments
join with ``+``). Used to match dict keys in ``packing_buffer_size`` /
``shuffle_buffer_size``.
_shuffle_over_epochs_multiplier: Inherited shuffle multiplier (merged per node like
``get_datasets``); default ``1``.
_subflavors: Effective subflavors implied by the traversed metadataset hierarchy.

Returns:
A flattened list of `TraversedDatasetReference` values for all leaf datasets reached
Expand All @@ -100,6 +122,7 @@ def get_datasets(
subflavors: Optional[Dict[str, Any]] = None,
shuffle_over_epochs_multiplier: Optional[int] = 1,
subset: Optional[DatasetSubset] = None,
group: Optional[str] = None,
**kwargs,
) -> LoadedDatasetList:
"""
Expand All @@ -120,6 +143,11 @@ def get_datasets(
an infinite number of epochs (effectively, this will draw shard slices with
replacement).
subset: If specified, the inner dataset(s) will be subsetted.
group: Dataset group for this leaf (merged with outer scopes). Datasets that share the same
non-``None`` key are blended and shuffled together, then each group's stream runs through
packing (:class:`~megatron.energon.PackingDataset`) before streams from different groups
are blended. Must match dict keys in ``packing_buffer_size`` / ``shuffle_buffer_size`` when
those are mappings. ``None`` selects the default group.
**kwargs: Additional arguments to the dataset constructor.

Returns:
Expand Down
Loading
Loading