diff --git a/docs/source/advanced/packing.md b/docs/source/advanced/packing.md index 45df53ac..f68d1cbc 100644 --- a/docs/source/advanced/packing.md +++ b/docs/source/advanced/packing.md @@ -18,8 +18,8 @@ and {py:meth}`pack_selected_samples ` with a ``pushback`` sequence: those samples are appended back to the reading buffer before the next fill from the dataset. For each group, the second method `pack_selected_samples` will be called. You need to implement how a group of samples will be mapped to a single sample. In terms of LLMs for example, this method might concatenate the input tokens. @@ -42,6 +42,83 @@ samples for packing. If augmentations happen, it should be marked with You have to make sure the methods are actually stateless, meaning that they will produce the same output when invoked with the same input and random states. +## Carrying over partial samples + +Sometimes the next sample only partially fits into the remaining packed context. In that case, +{py:meth}`select_samples_to_pack ` can return a +{py:class}`PartialSample ` for the part that fits and push back another +`PartialSample` for the remainder. + +`PartialSample` stores the original sample and a task-defined slice payload: + +```python +@edataclass +class PartialSample(Generic[T_sample, T_slice]): + sample: T_sample + slice: T_slice +``` + +The `slice` object is stored as-is in loader state and restore keys, so it must be serializable by +the same mechanism you use for checkpointing loader state. A `tuple[int, int]` using normal Python +`(start, stop)` slicing semantics is a typical choice. + +The slicing semantics are user-defined. Energon preserves and restores the `slice` payload, but your +task encoder applies it. 
If you override +{py:meth}`postencode_sample ` and produce partials, +`postencode_sample` must accept both full samples and `PartialSample` inputs: + +```python +def select_samples_to_pack( + self, + samples: list[TokenizedSample | PartialSample[TokenizedSample, tuple[int, int]]], +) -> PackedSamplesOutput[TokenizedSample | PartialSample[TokenizedSample, tuple[int, int]]]: + sample = samples[0] + if isinstance(sample, PartialSample): + base_sample = sample.sample + token_start, token_stop = sample.slice + else: + base_sample = sample + token_start = 0 + token_stop = len(sample.tokens) + + return PackedSamplesOutput( + packs=[ + [ + PartialSample( + sample=base_sample, + slice=(token_start, token_start + self.remaining_context), + ) + ] + ], + pushback=( + PartialSample( + sample=base_sample, + slice=(token_start + self.remaining_context, token_stop), + ), + ), + ) + + +@stateless +def postencode_sample( + self, + sample: TokenizedSample | PartialSample[TokenizedSample, tuple[int, int]], +) -> TokenizedSample: + if isinstance(sample, PartialSample): + token_start, token_stop = sample.slice + sample = sample.sample + return TokenizedSample.derive_from( + sample, + tokens=sample.tokens[token_start:token_stop], + ) + return sample +``` + +If you do not use `postencode_sample`, then +{py:meth}`pack_selected_samples ` receives the +`PartialSample` values directly and must apply the slice there. In that mode, type the final packer +to accept the same union of full and partial samples returned by `select_samples_to_pack`. + Example packing for a large language model extending the example from the [](../basic/task_encoder) section: ```python diff --git a/docs/source/basic/task_encoder.md b/docs/source/basic/task_encoder.md index 37295b0b..076a2709 100644 --- a/docs/source/basic/task_encoder.md +++ b/docs/source/basic/task_encoder.md @@ -19,8 +19,8 @@ and override one or more of the following methods. 
The data flow of {py:func}`ge - {py:meth}`def preencode_sample(self, sample: T_sample) -> T_sample ` - Can be used as a replacement to {py:func}`encode_sample ` - Only used together with {py:meth}`postencode_sample ` (below). Use it if you need to instantiate lazy data before packing the selected samples. But it will also work when not using packing. -- {py:meth}`def select_samples_to_pack(self, samples: List[T_encoded_sample]) -> List[List[T_encoded_sample]] ` - - Optional. Allows for efficient sample packing. See [](../advanced/packing). +- {py:meth}`def select_samples_to_pack(self, samples: List[T_encoded_sample]) -> Union[List[List[T_encoded_sample]], PackedSamplesOutput] ` + - Optional. Allows for efficient sample packing. See [](../advanced/packing). May return {py:class}`PackedSamplesOutput ` to re-queue a pushback sequence onto the reading buffer. - {py:meth}`def postencode_sample(self, sample: T_sample) -> T_encoded_sample ` - Only used together with {py:meth}`preencode_sample `. Use it if you need to instantiate lazy data before packing the selected samples. But it will also work when not using packing. 
- {py:meth}`def pack_selected_samples(self, samples: List[T_encoded_sample]) -> T_batch_sample] ` diff --git a/src/megatron/energon/__init__.py b/src/megatron/energon/__init__.py index e5204230..dc01739d 100644 --- a/src/megatron/energon/__init__.py +++ b/src/megatron/energon/__init__.py @@ -104,7 +104,9 @@ LogSampleDataset, MapDataset, MixBatchDataset, + PackedSamplesOutput, PackingDataset, + PartialSample, RepeatDataset, ShuffleBufferDataset, concat_pad, @@ -182,10 +184,12 @@ "NoCachePool", "OCRSample", "OCRWebdataset", + "PartialSample", "PackingDataset", "PrimaryFileStore", "prepare_metadataset", "RepeatDataset", + "PackedSamplesOutput", "reraise_exception", "Sample", "SampleDecoder", diff --git a/src/megatron/energon/metadataset/dataset_loader.py b/src/megatron/energon/metadataset/dataset_loader.py index 46e93588..b115ac9c 100644 --- a/src/megatron/energon/metadataset/dataset_loader.py +++ b/src/megatron/energon/metadataset/dataset_loader.py @@ -90,6 +90,7 @@ def get_datasets( subflavors: Optional[Dict[str, Any]] = None, shuffle_over_epochs_multiplier: Optional[int] = 1, subset: Optional[DatasetSubset] = None, + group: Optional[str] = None, **kwargs, ) -> LoadedDatasetList: return LoadedDatasetList( @@ -106,6 +107,7 @@ def get_datasets( **kwargs, ), weight=None, + group=group, ) ], ) diff --git a/src/megatron/energon/metadataset/join_dataset_loader.py b/src/megatron/energon/metadataset/join_dataset_loader.py index 4ea1cbd4..7e00791b 100644 --- a/src/megatron/energon/metadataset/join_dataset_loader.py +++ b/src/megatron/energon/metadataset/join_dataset_loader.py @@ -541,6 +541,7 @@ def get_datasets( subflavors: Optional[Dict[str, Any]] = None, shuffle_over_epochs_multiplier: Optional[int] = 1, subset: Optional[DatasetSubset] = None, + group: Optional[str] = None, **kwargs, ) -> LoadedDatasetList: return LoadedDatasetList( @@ -557,6 +558,7 @@ def get_datasets( **kwargs, ), weight=None, + group=group, ) ], ) diff --git 
a/src/megatron/energon/metadataset/loader_interface.py b/src/megatron/energon/metadataset/loader_interface.py index ba7c6ea5..28779e26 100644 --- a/src/megatron/energon/metadataset/loader_interface.py +++ b/src/megatron/energon/metadataset/loader_interface.py @@ -27,9 +27,17 @@ class DatasetBlendMode(Enum): @edataclass class LoadedDataset: + #: The dataset factory. dataset: BaseCoreDatasetFactory + #: Sampling weight when using dataset-weight blending. weight: Union[float, int, None] = None + #: Epochized repetition count when using repetition-based blending. repetitions: Union[float, int, None] = None + #: Dataset group key from Metadataset V2 (YAML ``group`` field). Must match keys used in + #: ``packing_buffer_size`` / ``shuffle_buffer_size`` when those are dicts. ``None`` is the + #: default group. + group: Optional[str] = None + #: Auxiliary datasets for crude cooking. aux: Optional[Dict[str, FileStore]] = None @@ -48,12 +56,18 @@ class TraversedDatasetReference: split_part: Effective split part to use when loading the leaf dataset. aux: Resolved auxiliary dataset or filesystem references keyed by auxiliary name. subflavors: Effective subflavors implied by the traversed metadataset hierarchy. + group: Merged dataset group name from Metadataset V2 ``group`` fields along the path to this + leaf (nested segments join with ``+``). Used as the dict key for per-group + ``packing_buffer_size`` and ``shuffle_buffer_size`` entries. + shuffle_over_epochs_multiplier: Effective shuffle over epochs multiplier from metadataset references. 
""" path: EPath split_part: str aux: dict[str, EPath] subflavors: dict[str, Any] + group: Optional[str] = None + shuffle_over_epochs_multiplier: Optional[int] = 1 class DatasetLoaderInterface(ABC): @@ -69,6 +83,8 @@ def traverse( mds_path: Optional[EPath] = None, *, split_part: Union[Literal["train", "val", "test"], str], + _group: Optional[str] = None, + _shuffle_over_epochs_multiplier: Optional[int] = 1, _subflavors: Optional[Dict[str, Any]] = None, ) -> List[TraversedDatasetReference]: """Traverse a metadataset subtree and collect flattened leaf dataset references. @@ -83,6 +99,12 @@ def traverse( use None only for top-level metadatasets. split_part: Split to traverse, such as `\"train\"`, `\"val\"`, or `\"test\"`. Nested references may override this with their own configured split. + _group: Inherited merged group name from Metadataset V2 ``group`` fields (nested segments + join with ``+``). Used to match dict keys in ``packing_buffer_size`` / + ``shuffle_buffer_size``. + _shuffle_over_epochs_multiplier: Inherited shuffle multiplier (merged per node like + ``get_datasets``); default ``1``. + _subflavors: Effective subflavors implied by the traversed metadataset hierarchy. Returns: A flattened list of `TraversedDatasetReference` values for all leaf datasets reached @@ -100,6 +122,7 @@ def get_datasets( subflavors: Optional[Dict[str, Any]] = None, shuffle_over_epochs_multiplier: Optional[int] = 1, subset: Optional[DatasetSubset] = None, + group: Optional[str] = None, **kwargs, ) -> LoadedDatasetList: """ @@ -120,6 +143,11 @@ def get_datasets( an infinite number of epochs (effectively, this will draw shard slices with replacement). subset: If specified, the inner dataset(s) will be subsetted. + group: Dataset group for this leaf (merged with outer scopes). 
Datasets that share the same + non-``None`` key are blended and shuffled together, then each group's stream runs through + packing (:class:`~megatron.energon.PackingDataset`) before streams from different groups + are blended. Must match dict keys in ``packing_buffer_size`` / ``shuffle_buffer_size`` when + those are mappings. ``None`` selects the default group. **kwargs: Additional arguments to the dataset constructor. Returns: diff --git a/src/megatron/energon/metadataset/metadataset.py b/src/megatron/energon/metadataset/metadataset.py index ec8efa45..8ce475aa 100644 --- a/src/megatron/energon/metadataset/metadataset.py +++ b/src/megatron/energon/metadataset/metadataset.py @@ -81,6 +81,22 @@ def _merge_traversed_subflavors( return {**self.subflavors, **(inherited_subflavors or {})} return dict(inherited_subflavors or {}) + def _merge_shuffle_over_epochs_multiplier( + self, inherited_shuffle_over_epochs_multiplier: Optional[int] + ) -> Optional[int]: + """Same semantics as Metadataset V2 ``ShuffleOverEpochsMultiplierMixin`` / ``get_datasets``.""" + if ( + inherited_shuffle_over_epochs_multiplier is None + or self.shuffle_over_epochs_multiplier is None + ): + return None + if ( + inherited_shuffle_over_epochs_multiplier == -1 + or self.shuffle_over_epochs_multiplier == -1 + ): + return -1 + return inherited_shuffle_over_epochs_multiplier * self.shuffle_over_epochs_multiplier + def post_initialize(self, mds_path: Optional[EPath] = None): self._resolve_path(mds_path) if self.path.is_file(): @@ -101,35 +117,30 @@ def traverse( mds_path: Optional[EPath] = None, *, split_part: Union[Literal["train", "val", "test"], str], + _group: Optional[str] = None, + _shuffle_over_epochs_multiplier: Optional[int] = 1, _subflavors: Optional[Dict[str, Any]] = None, ) -> List[TraversedDatasetReference]: - """Traverse this V1 dataset reference into flattened leaf references. - - Args: - mds_path: Parent metadataset path used internally to resolve relative dataset and - auxiliary paths. 
Must be set for nested references and inner traversal nodes; - use None only for top-level metadatasets. - split_part: Split inherited from the parent traversal. If this reference defines its - own split override, that split takes precedence for nested traversal and the - returned leaf reference. - - Returns: - A single leaf `TraversedDatasetReference` for direct dataset references, or the - flattened traversal result of the nested metadataset when this reference points to one. - """ self._resolve_path(mds_path) - effective_subflavors = self._merge_traversed_subflavors(_subflavors) + _subflavors = self._merge_traversed_subflavors(_subflavors) + _shuffle_over_epochs_multiplier = self._merge_shuffle_over_epochs_multiplier( + _shuffle_over_epochs_multiplier + ) if self.path.is_file(): return self._load_nested_metadataset().traverse( split_part=self.split_part or split_part, - _subflavors=effective_subflavors, + _group=_group, + _shuffle_over_epochs_multiplier=_shuffle_over_epochs_multiplier, + _subflavors=_subflavors, ) return [ TraversedDatasetReference( path=self.path, split_part=self.split_part or split_part, aux={}, - subflavors=effective_subflavors, + subflavors=_subflavors, + group=_group, + shuffle_over_epochs_multiplier=_shuffle_over_epochs_multiplier, ) ] @@ -142,6 +153,7 @@ def get_datasets( subflavors: Optional[Dict[str, Any]] = None, shuffle_over_epochs_multiplier: Optional[int] = 1, subset: Optional[DatasetSubset] = None, + group: Optional[str] = None, **kwargs, ) -> LoadedDatasetList: if self.subflavors is not None: @@ -167,6 +179,7 @@ def get_datasets( subflavors=subflavors, shuffle_over_epochs_multiplier=new_shuffle_over_epochs_multiplier, subset=subset, + group=group, **kwargs, ) @@ -187,6 +200,8 @@ def traverse( mds_path: Optional[EPath] = None, *, split_part: Union[Literal["train", "val", "test"], str], + _group: Optional[str] = None, + _shuffle_over_epochs_multiplier: Optional[int] = 1, _subflavors: Optional[Dict[str, Any]] = None, ) -> 
List[TraversedDatasetReference]: assert mds_path is not None @@ -196,6 +211,8 @@ def traverse( dataset.traverse( mds_path, split_part=split_part, + _group=_group, + _shuffle_over_epochs_multiplier=_shuffle_over_epochs_multiplier, _subflavors=_subflavors, ) ) @@ -210,6 +227,7 @@ def get_datasets( subflavors: Optional[Dict[str, Any]] = None, shuffle_over_epochs_multiplier: Optional[int] = 1, subset: Optional[DatasetSubset] = None, + group: Optional[str] = None, **kwargs, ) -> LoadedDatasetList: sum_weight = sum(dataset.weight for dataset in self.datasets) @@ -222,6 +240,7 @@ def get_datasets( subflavors=subflavors, shuffle_over_epochs_multiplier=shuffle_over_epochs_multiplier, subset=subset, + group=group, **kwargs, ) if inner_result.blend_mode not in ( @@ -270,21 +289,16 @@ def traverse( mds_path: Optional[EPath] = None, *, split_part: Union[Literal["train", "val", "test"], str], + _group: Optional[str] = None, + _shuffle_over_epochs_multiplier: Optional[int] = 1, _subflavors: Optional[Dict[str, Any]] = None, ) -> List[TraversedDatasetReference]: - """Traverse the selected V1 split and flatten all reachable leaf references. - - Args: - mds_path: Unused for top-level metadatasets. Present to satisfy the shared interface. - split_part: Split to traverse. - - Returns: - The flattened list of traversed leaf dataset references for `split_part`. 
- """ assert mds_path is None return self._splits[split_part].traverse( self._path, split_part=split_part, + _group=_group, + _shuffle_over_epochs_multiplier=_shuffle_over_epochs_multiplier, _subflavors=_subflavors, ) @@ -297,6 +311,7 @@ def get_datasets( subflavors: Optional[Dict[str, Any]] = None, shuffle_over_epochs_multiplier: Optional[int] = 1, subset: Optional[DatasetSubset] = None, + group: Optional[str] = None, **kwargs, ) -> LoadedDatasetList: return self._splits[split_part].get_datasets( @@ -306,5 +321,6 @@ def get_datasets( subflavors=subflavors, shuffle_over_epochs_multiplier=shuffle_over_epochs_multiplier, subset=subset, + group=group, **kwargs, ) diff --git a/src/megatron/energon/metadataset/metadataset_v2.py b/src/megatron/energon/metadataset/metadataset_v2.py index 6f277e95..5f323b1d 100644 --- a/src/megatron/energon/metadataset/metadataset_v2.py +++ b/src/megatron/energon/metadataset/metadataset_v2.py @@ -173,7 +173,7 @@ def merge(self, parent_subset: DatasetSubset | None) -> DatasetSubset: ) -@edataclass +@dataclass(kw_only=True, eq=False) class SubsetRatioMixin: subset: Optional[Subset] = None @@ -191,13 +191,80 @@ def _get_subset(self, parent_subset: Optional[DatasetSubset]) -> Optional[Datase return None +@dataclass(kw_only=True, eq=False) +class GroupMixin: + #: Names a dataset group for train pipelines (blend/shuffle/pack): datasets sharing the + #: same non-None string are blended and shuffled together before packing; packed streams from + #: different groups are then blended. ``None`` inherits the group of the outer scope. 
+ group: Optional[str] = None + + def _merge_group(self, inherited_group: Optional[str]) -> Optional[str]: + if self.group is not None: + if inherited_group is not None: + return f"{inherited_group}+{self.group}" + return self.group + return inherited_group + + +@dataclass(kw_only=True, eq=False) +class ShuffleOverEpochsMultiplierMixin: + shuffle_over_epochs_multiplier: Optional[int] = 1 + + def _merge_shuffle_over_epochs_multiplier( + self, inherited_shuffle_over_epochs_multiplier: Optional[int] + ) -> Optional[int]: + if ( + inherited_shuffle_over_epochs_multiplier is None + or self.shuffle_over_epochs_multiplier is None + ): + # If no shuffling is requested, this has override priority. + return None + elif ( + inherited_shuffle_over_epochs_multiplier == -1 + or self.shuffle_over_epochs_multiplier == -1 + ): + # Next priority is sampling without replacement. + return -1 + else: + # Otherwise, multiply the shuffle over epochs multiplier. + return inherited_shuffle_over_epochs_multiplier * self.shuffle_over_epochs_multiplier + + +@dataclass(kw_only=True, eq=False) +class SubflavorsMixin: + subflavors: Optional[Dict[str, Any]] = None + + def _merge_subflavors(self, inherited_subflavors: Optional[Dict[str, Any]]) -> Dict[str, Any]: + """Merge this reference's subflavors with the inherited traversal subflavors. + + The merge order mirrors `get_datasets(...)`: this reference contributes the base mapping, + and inherited outer-hierarchy subflavors override on key conflicts. + + + Args: + inherited_subflavors: Effective subflavors accumulated from outer metadataset + references during traversal. + + Returns: + The effective subflavor mapping for this reference, after applying outer-overrides-inner + merge semantics. 
+ """ + if self.subflavors is not None: + return {**self.subflavors, **(inherited_subflavors or {})} + return dict(inherited_subflavors or {}) + + @edataclass -class DatasetReference(SubsetRatioMixin, DatasetLoaderInterface): +class DatasetReference( + SubsetRatioMixin, + GroupMixin, + ShuffleOverEpochsMultiplierMixin, + SubflavorsMixin, + DatasetLoaderInterface, +): path: Union[str, EPath] split_part: Optional[str] = None - subflavors: Optional[Dict[str, Any]] = None - shuffle_over_epochs_multiplier: Optional[int] = 1 dataset_config: Optional[str] = None split_config: Optional[str] = None @@ -269,26 +336,6 @@ def _get_traversed_aux_references(self) -> dict[str, EPath]: traversed_aux[key] = value.fs_path return traversed_aux - def _merge_traversed_subflavors( - self, inherited_subflavors: Optional[Dict[str, Any]] - ) -> Dict[str, Any]: - """Merge this reference's subflavors with the inherited traversal subflavors. - - The merge order mirrors `get_datasets(...)`: this reference contributes the base mapping, - and inherited outer-hierarchy subflavors override on key conflicts. - - Args: - inherited_subflavors: Effective subflavors accumulated from outer metadataset - references during traversal. - - Returns: - The effective subflavor mapping for this reference, after applying outer-overrides-inner - merge semantics. 
- """ - if self.subflavors is not None: - return {**self.subflavors, **(inherited_subflavors or {})} - return dict(inherited_subflavors or {}) - def _load_nested_metadataset(self) -> DatasetLoaderInterface: assert isinstance(self.path, EPath) assert self.aux is None, "Cannot specify auxiliary datasets for crude datasets" @@ -327,33 +374,24 @@ def traverse( mds_path: Optional[EPath] = None, *, split_part: Union[Literal["train", "val", "test"], str], + _group: Optional[str] = None, + _shuffle_over_epochs_multiplier: Optional[int] = 1, _subflavors: Optional[Dict[str, Any]] = None, ) -> List[TraversedDatasetReference]: - """Traverse this V2 dataset reference into flattened leaf references. - - For direct leaf datasets, traversal resolves the dataset path and any auxiliary references - into plain `EPath` values. For nested metadatasets, traversal recurses immediately into the - referenced split instead of building an intermediate object graph. - Args: - mds_path: Parent metadataset path used internally to resolve relative dataset and - auxiliary paths. Must be set for nested references and inner traversal nodes; - use None only for top-level metadatasets. - split_part: Split inherited from the parent traversal. If this reference defines its - own split override, that split takes precedence for nested traversal and the - returned leaf reference. - - Returns: - A single leaf `TraversedDatasetReference` for direct dataset references, or the - flattened traversal result of the nested metadataset when this reference points to one. 
- """ self._resolve_path(mds_path) - effective_subflavors = self._merge_traversed_subflavors(_subflavors) + _subflavors = self._merge_subflavors(_subflavors) + _group = self._merge_group(_group) + _shuffle_over_epochs_multiplier = self._merge_shuffle_over_epochs_multiplier( + _shuffle_over_epochs_multiplier + ) ds_type = get_dataset_type(self.path) if ds_type == EnergonDatasetType.METADATASET: return self._load_nested_metadataset().traverse( split_part=self.split_part or split_part, - _subflavors=effective_subflavors, + _group=_group, + _shuffle_over_epochs_multiplier=_shuffle_over_epochs_multiplier, + _subflavors=_subflavors, ) self._normalize_aux_references(mds_path, validate=False) return [ @@ -361,7 +399,9 @@ def traverse( path=self.path, split_part=self.split_part or split_part, aux=self._get_traversed_aux_references(), - subflavors=effective_subflavors, + subflavors=_subflavors, + group=_group, + shuffle_over_epochs_multiplier=_shuffle_over_epochs_multiplier, ) ] @@ -378,32 +418,21 @@ def get_datasets( subflavors: Optional[Dict[str, Any]] = None, shuffle_over_epochs_multiplier: Optional[int] = 1, subset: Optional[DatasetSubset] = None, + group: Optional[str] = None, **kwargs, ) -> LoadedDatasetList: - if self.subflavors is not None: - subflavors = {**self.subflavors, **(subflavors or {})} assert self._dataset is not None - if shuffle_over_epochs_multiplier is None or self.shuffle_over_epochs_multiplier is None: - # If no shuffling is requested, this has override priority. - new_shuffle_over_epochs_multiplier = None - elif shuffle_over_epochs_multiplier == -1 or self.shuffle_over_epochs_multiplier == -1: - # Next priority is sampling without replacement. - new_shuffle_over_epochs_multiplier = -1 - else: - # Otherwise, multiply the shuffle over epochs multiplier. 
- new_shuffle_over_epochs_multiplier = ( - shuffle_over_epochs_multiplier * self.shuffle_over_epochs_multiplier - ) - subset = self._get_subset(subset) - result = self._dataset.get_datasets( training=training, split_part=self.split_part or split_part, worker_config=worker_config, - subflavors=subflavors, - shuffle_over_epochs_multiplier=new_shuffle_over_epochs_multiplier, - subset=subset, + subflavors=self._merge_subflavors(subflavors), + shuffle_over_epochs_multiplier=self._merge_shuffle_over_epochs_multiplier( + shuffle_over_epochs_multiplier + ), + subset=self._get_subset(subset), + group=self._merge_group(group), **kwargs, ) if self.aux is not None: @@ -443,6 +472,8 @@ def traverse( mds_path: Optional[EPath] = None, *, split_part: Union[Literal["train", "val", "test"], str], + _group: Optional[str] = None, + _shuffle_over_epochs_multiplier: Optional[int] = 1, _subflavors: Optional[Dict[str, Any]] = None, ) -> List[TraversedDatasetReference]: raise NotImplementedError("traverse_metadataset() does not support joined datasets.") @@ -462,13 +493,17 @@ def get_datasets( @edataclass -class MetadatasetJoin(SubsetRatioMixin, DatasetLoaderInterface): +class MetadatasetJoin( + SubsetRatioMixin, + GroupMixin, + ShuffleOverEpochsMultiplierMixin, + SubflavorsMixin, + DatasetLoaderInterface, +): join: Union[List[JoinDatasetReference], Dict[str, JoinDatasetReference]] joiner: Union[Type[Sample], Callable[..., Sample]] split_part: Optional[str] = None - subflavors: Optional[Dict[str, Any]] = None - shuffle_over_epochs_multiplier: Optional[int] = 1 dataset_config: Optional[str] = None split_config: Optional[str] = None @@ -514,6 +549,8 @@ def traverse( mds_path: Optional[EPath] = None, *, split_part: Union[Literal["train", "val", "test"], str], + _group: Optional[str] = None, + _shuffle_over_epochs_multiplier: Optional[int] = 1, _subflavors: Optional[Dict[str, Any]] = None, ) -> List[TraversedDatasetReference]: raise NotImplementedError("traverse_metadataset() does not support 
joined datasets.") @@ -531,17 +568,20 @@ def get_datasets( subflavors: Optional[Dict[str, Any]] = None, shuffle_over_epochs_multiplier: Optional[int] = 1, subset: Optional[DatasetSubset] = None, + group: Optional[str] = None, **kwargs, ) -> LoadedDatasetList: assert self._dataset is not None, "Missing post_initialize call." - subset = self._get_subset(subset) return self._dataset.get_datasets( training=training, split_part=split_part, worker_config=worker_config, - subflavors=subflavors, - shuffle_over_epochs_multiplier=shuffle_over_epochs_multiplier, - subset=subset, + subflavors=self._merge_subflavors(subflavors), + shuffle_over_epochs_multiplier=self._merge_shuffle_over_epochs_multiplier( + shuffle_over_epochs_multiplier + ), + subset=self._get_subset(subset), + group=self._merge_group(group), **kwargs, ) @@ -562,10 +602,16 @@ class BlendJoinDatasetReference(BlendWeightMixin, MetadatasetJoin): @edataclass -class MetadatasetBlend(DatasetLoaderInterface, SubsetRatioMixin): +class MetadatasetBlend( + SubsetRatioMixin, + GroupMixin, + ShuffleOverEpochsMultiplierMixin, + SubflavorsMixin, + DatasetLoaderInterface, +): """Blending of datasets by specifying the sampling weight for the inner datasets.""" - blend: List[Union[BlendDatasetReference, BlendJoinDatasetReference]] + blend: List[Union[BlendDatasetReference, BlendJoinDatasetReference, "MetadatasetBlend"]] def post_initialize(self, mds_path: Optional[EPath] = None): assert mds_path is not None @@ -577,15 +623,24 @@ def traverse( mds_path: Optional[EPath] = None, *, split_part: Union[Literal["train", "val", "test"], str], + _group: Optional[str] = None, + _shuffle_over_epochs_multiplier: Optional[int] = 1, _subflavors: Optional[Dict[str, Any]] = None, ) -> List[TraversedDatasetReference]: assert mds_path is not None + _group = self._merge_group(_group) + _shuffle_over_epochs_multiplier = self._merge_shuffle_over_epochs_multiplier( + _shuffle_over_epochs_multiplier + ) + _subflavors = 
self._merge_subflavors(_subflavors) flattened: List[TraversedDatasetReference] = [] for dataset in self.blend: flattened.extend( dataset.traverse( mds_path, split_part=split_part, + _group=_group, + _shuffle_over_epochs_multiplier=_shuffle_over_epochs_multiplier, _subflavors=_subflavors, ) ) @@ -606,9 +661,15 @@ def get_datasets( subflavors: Optional[Dict[str, Any]] = None, shuffle_over_epochs_multiplier: Optional[int] = 1, subset: Optional[DatasetSubset] = None, + group: Optional[str] = None, **kwargs, ) -> LoadedDatasetList: subset = self._get_subset(subset) + group = self._merge_group(group) + subflavors = self._merge_subflavors(subflavors) + shuffle_over_epochs_multiplier = self._merge_shuffle_over_epochs_multiplier( + shuffle_over_epochs_multiplier + ) sum_weight = sum(dataset.weight for dataset in self.blend) datasets = [] for dataset in self.blend: @@ -619,6 +680,7 @@ def get_datasets( subflavors=subflavors, shuffle_over_epochs_multiplier=shuffle_over_epochs_multiplier, subset=subset, + group=group, **kwargs, ) if inner_result.blend_mode not in ( @@ -660,13 +722,25 @@ class BlendEpochizedJoinDatasetReference(BlendRepetitionsMixin, MetadatasetJoin) @edataclass -class MetadatasetBlendEpochized(SubsetRatioMixin, DatasetLoaderInterface): +class MetadatasetBlendEpochized( + SubsetRatioMixin, + GroupMixin, + ShuffleOverEpochsMultiplierMixin, + SubflavorsMixin, + DatasetLoaderInterface, +): """Blending of datasets, by specifying the number of repetitions for samples from the inner datasets. Ensures that the constraint, that samples are seen exactly this many times before repeating the "epoch" (i.e. 
one epoch contains the total number of repetitions for each inner dataset).""" - blend_epochized: List[Union[BlendEpochizedDatasetReference, BlendEpochizedJoinDatasetReference]] + blend_epochized: List[ + Union[ + BlendEpochizedDatasetReference, + BlendEpochizedJoinDatasetReference, + "MetadatasetBlendEpochized", + ] + ] def post_initialize(self, mds_path: Optional[EPath] = None): assert mds_path is not None @@ -678,15 +752,24 @@ def traverse( mds_path: Optional[EPath] = None, *, split_part: Union[Literal["train", "val", "test"], str], + _group: Optional[str] = None, + _shuffle_over_epochs_multiplier: Optional[int] = 1, _subflavors: Optional[Dict[str, Any]] = None, ) -> List[TraversedDatasetReference]: assert mds_path is not None flattened: List[TraversedDatasetReference] = [] + _group = self._merge_group(_group) + _shuffle_over_epochs_multiplier = self._merge_shuffle_over_epochs_multiplier( + _shuffle_over_epochs_multiplier + ) + _subflavors = self._merge_subflavors(_subflavors) for dataset in self.blend_epochized: flattened.extend( dataset.traverse( mds_path, split_part=split_part, + _group=_group, + _shuffle_over_epochs_multiplier=_shuffle_over_epochs_multiplier, _subflavors=_subflavors, ) ) @@ -707,9 +790,15 @@ def get_datasets( subflavors: Optional[Dict[str, Any]] = None, shuffle_over_epochs_multiplier: Optional[int] = 1, subset: Optional[DatasetSubset] = None, + group: Optional[str] = None, **kwargs, ) -> LoadedDatasetList: subset = self._get_subset(subset) + shuffle_over_epochs_multiplier = self._merge_shuffle_over_epochs_multiplier( + shuffle_over_epochs_multiplier + ) + group = self._merge_group(group) + subflavors = self._merge_subflavors(subflavors) datasets = [] for dataset in self.blend_epochized: inner_result = dataset.get_datasets( @@ -719,6 +808,7 @@ def get_datasets( subflavors=subflavors, shuffle_over_epochs_multiplier=shuffle_over_epochs_multiplier, subset=subset, + group=group, **kwargs, ) if inner_result.blend_mode not in ( @@ -760,6 +850,8 @@ 
def traverse( mds_path: Optional[EPath] = None, *, split_part: Union[Literal["train", "val", "test"], str], + _group: Optional[str] = None, + _shuffle_over_epochs_multiplier: Optional[int] = 1, _subflavors: Optional[Dict[str, Any]] = None, ) -> List[TraversedDatasetReference]: """Traverse the selected V2 split and flatten all reachable leaf references. @@ -775,6 +867,8 @@ def traverse( return self.splits[split_part].traverse( self.path, split_part=split_part, + _group=_group, + _shuffle_over_epochs_multiplier=_shuffle_over_epochs_multiplier, _subflavors=_subflavors, ) @@ -808,6 +902,7 @@ def get_datasets( subflavors: Optional[Dict[str, Any]] = None, shuffle_over_epochs_multiplier: Optional[int] = 1, subset: Optional[DatasetSubset] = None, + group: Optional[str] = None, **kwargs, ) -> LoadedDatasetList: return self.splits[split_part].get_datasets( @@ -817,5 +912,6 @@ def get_datasets( subflavors=subflavors, shuffle_over_epochs_multiplier=shuffle_over_epochs_multiplier, subset=subset, + group=group, **kwargs, ) diff --git a/src/megatron/energon/task_encoder/base.py b/src/megatron/energon/task_encoder/base.py index 5335d018..9032bc30 100644 --- a/src/megatron/energon/task_encoder/base.py +++ b/src/megatron/energon/task_encoder/base.py @@ -54,8 +54,10 @@ LogSampleDataset, MapDataset, PackingDataset, + PartialSample, ShuffleBufferDataset, ) +from megatron.energon.wrappers.packing_dataset import PackedSamplesOutput from megatron.energon.wrappers.repeat_dataset import RepeatDataset T = TypeVar("T") @@ -449,10 +451,14 @@ def preencode_sample( return sample @stateless - def postencode_sample(self, sample: T_sample) -> T_encoded_sample: + def postencode_sample( + self, sample: T_sample | PartialSample[T_sample, Any] + ) -> T_encoded_sample: """Post-encode a single sample. May raise :exc:`megatron.energon.SkipSample` to skip a sample. Alternatively, this can be a generator that yields (or ignores) new samples. Use in conjunction with packing and caching. 
+ When partial samples are returned by :meth:`select_samples_to_pack`, this method must + handle both full samples and :class:`PartialSample` inputs. If this is defined, :func:`encode_sample` must not be defined. """ return sample @@ -535,90 +541,125 @@ def _batch( raise ValueError("Unrecognized result type.") def select_samples_to_pack( - self, samples: List[T_encoded_sample] - ) -> List[List[T_encoded_sample]]: + self, samples: List[T_encoded_sample | PartialSample[T_encoded_sample, Any]] + ) -> ( + list[list[T_encoded_sample | PartialSample[T_encoded_sample, Any]]] + | PackedSamplesOutput[T_encoded_sample | PartialSample[T_encoded_sample, Any]] + ): """ For packing, selects the samples to be packed together. Packing is only active when packing_buffer_size is set. Internally this stage is called "pre_packing". Args: - samples: The samples to pre-pack. A full buffer will be passed into the function. + samples: The samples to pre-pack (a full reading buffer per call when ``packing_buffer_size`` is set). - Returns: The pre-packed samples as a list of lists of samples. + Returns: + Either a ``list[list[T]]`` of packs, or :class:`PackedSamplesOutput` + to attach a ``pushback`` sequence reapplied to the reading buffer before the next fill. + Packs and pushback may contain :class:`PartialSample` values for user-defined slices. """ raise NotImplementedError("Packing only effective when overridden.") - def pack_selected_samples(self, samples: List[T_encoded_sample]) -> T_encoded_sample: + def pack_selected_samples( + self, samples: List[T_encoded_sample | PartialSample[T_encoded_sample, Any]] + ) -> T_encoded_sample: """ Given one set of samples to pack, returns the final packed sample. Packing is only active when packing_buffer_size is set. Internally this stage is called "final_packing". Args: - samples: The samples to pack into a single sample + samples: The samples to pack into a single sample. 
If partial samples were selected + and no post-encoding step is configured, this list may contain + :class:`PartialSample` values. Returns: The final packed sample. """ raise NotImplementedError("Packing only effective when overridden.") - def build_batch( + def _build_packing_postencode( self, dataset: SavableDataset[T_encoded_sample], *, - batch_size: Optional[int], - batch_drop_last: bool = False, - packing_buffer_size: Optional[int] = None, + group: Optional[str], + packing_buffer_size: int | dict[str | None, int | None] | None, worker_config: WorkerConfig, - ) -> SavableDataset[T_raw_batch]: - """Applies the batcher to the dataset.""" - - dataset: SavableDataset[Any] + ) -> SavableDataset[T_encoded_sample]: + """Builds the (packing +) post-encode stage after encoding. - if packing_buffer_size is not None: - select_samples_to_pack_provided = self._is_overridden(self.select_samples_to_pack) - pack_selected_samples_provided = self._is_overridden(self.pack_selected_samples) + When ``packing_buffer_size`` is a dict, selects the buffer size (or ``None`` to disable + packing) for this leaf's dataset group via ``group`` (must match + :attr:`~megatron.energon.metadataset.loader_interface.LoadedDataset.group`). - assert select_samples_to_pack_provided and pack_selected_samples_provided, ( - "Both select_samples_to_pack and pack_selected_samples methods must be provided in the TaskEncoder when using packing_buffer_size" - ) + Args: + dataset: Encoded sample stream for one group. + group: Key into ``packing_buffer_size`` when it is a dict; ``None`` is the default group. + packing_buffer_size: Global buffer size, per-group mapping, or ``None`` to disable packing. + worker_config: Worker configuration for wrapped datasets. 
+ """ + if isinstance(packing_buffer_size, dict): + packing_buffer_size = packing_buffer_size[group] + if packing_buffer_size is None: if self._is_overridden(self.postencode_sample): - post_encode_fn = self.postencode_sample - else: - post_encode_fn = None + dataset = MapDataset( + dataset, + self.postencode_sample, + worker_config=worker_config, + stateless_map_fn=get_stateless(self.postencode_sample), + failure_tolerance=get_failure_tolerance( + self.postencode_sample, self.__default_failure_tolerance__ + ), + ) + return dataset - dataset = PackingDataset( - dataset, - buffer_size=packing_buffer_size, - pre_packer=self.select_samples_to_pack, - final_packer=self.pack_selected_samples, - final_packer_stateless=get_stateless(self.pack_selected_samples), - sample_encoder=post_encode_fn, - sample_encoder_stateless=True - if post_encode_fn is None - else get_stateless(post_encode_fn), - worker_config=worker_config, - pre_packer_failure_tolerance=get_failure_tolerance( - self.select_samples_to_pack, self.__default_failure_tolerance__ - ), - final_packer_failure_tolerance=get_failure_tolerance( - self.pack_selected_samples, self.__default_failure_tolerance__ - ), - sample_encoder_failure_tolerance=0 - if post_encode_fn is None - else get_failure_tolerance(post_encode_fn, self.__default_failure_tolerance__), - ) - elif self._is_overridden(self.postencode_sample): - dataset = MapDataset( - dataset, - self.postencode_sample, - worker_config=worker_config, - stateless_map_fn=get_stateless(self.postencode_sample), - failure_tolerance=get_failure_tolerance( - self.postencode_sample, self.__default_failure_tolerance__ - ), + select_samples_to_pack_provided = self._is_overridden(self.select_samples_to_pack) + pack_selected_samples_provided = self._is_overridden(self.pack_selected_samples) + + assert select_samples_to_pack_provided and pack_selected_samples_provided, ( + "Both select_samples_to_pack and pack_selected_samples methods must be provided in the TaskEncoder when using 
packing_buffer_size" + ) + + if self._is_overridden(self.postencode_sample): + post_encode_fn = self.postencode_sample + post_encode_stateless = get_stateless(self.postencode_sample) + post_encode_failure_tolerance = get_failure_tolerance( + self.postencode_sample, self.__default_failure_tolerance__ ) + else: + post_encode_fn = None + post_encode_stateless = True + post_encode_failure_tolerance = 0 + + return PackingDataset( + dataset, + buffer_size=packing_buffer_size, + pre_packer=self.select_samples_to_pack, + final_packer=self.pack_selected_samples, + final_packer_stateless=get_stateless(self.pack_selected_samples), + sample_encoder=post_encode_fn, + sample_encoder_stateless=post_encode_stateless, + worker_config=worker_config, + pre_packer_failure_tolerance=get_failure_tolerance( + self.select_samples_to_pack, self.__default_failure_tolerance__ + ), + final_packer_failure_tolerance=get_failure_tolerance( + self.pack_selected_samples, self.__default_failure_tolerance__ + ), + sample_encoder_failure_tolerance=post_encode_failure_tolerance, + ) + + def build_batch( + self, + dataset: SavableDataset[T_encoded_sample], + *, + batch_size: Optional[int], + batch_drop_last: bool = False, + worker_config: WorkerConfig, + ) -> SavableDataset[T_raw_batch]: + """Applies the batcher to the dataset.""" + dataset: SavableDataset[Any] if self._is_overridden(self.batch_group_criterion): dataset = GroupBatchDataset( @@ -769,33 +810,36 @@ def build_encode_sample( ) return dataset - def build_train_datasets( + def _group_weight( + self, + group_ds: List[LoadedDataset], + blend_mode: DatasetBlendMode, + *, + repeat: bool, + ) -> float: + """Blend weight for one dataset group when merging packed streams after grouping.""" + if blend_mode == DatasetBlendMode.DATASET_WEIGHT: + return sum(float(d.weight) for d in group_ds) + if blend_mode == DatasetBlendMode.SAMPLE_REPETITIONS or ( + not repeat and blend_mode == DatasetBlendMode.NONE + ): + return sum( + len(d.dataset) * (1 if 
d.repetitions is None else float(d.repetitions)) + for d in group_ds + ) + return float(len(group_ds)) + + def _build_train_blend_shuffle_encode_branch( self, *, datasets: List[LoadedDataset], + worker_rotation_offsets: List[int], + blend_mode: DatasetBlendMode, + repeat: bool, + shuffle_buffer_size: Optional[int], worker_config: WorkerConfig, - batch_size: Optional[int], - batch_drop_last: bool = False, - packing_buffer_size: Optional[int] = None, - virtual_epoch_length: int = 0, - shuffle_buffer_size: Optional[int] = None, - blend_mode: DatasetBlendMode = DatasetBlendMode.NONE, - repeat: bool = True, - ) -> SavableDataset[T_batch]: - """Combines train datasets to a single dataset.""" - - # Check if there's a CrudeWebdataset but no cookers - for dataset in datasets: - if isinstance(dataset.dataset, CrudeWebdataset): - assert self.cookers, "CrudeWebdataset found, but no cookers registered." - - global_workers = max(1, worker_config.num_workers) * worker_config.world_size - rotation_lengths = [len(dataset.dataset) for dataset in datasets] - for i in range(1, len(rotation_lengths)): - rotation_lengths[i] += rotation_lengths[i - 1] - worker_rotation_offsets = [ - rotation_length % global_workers for rotation_length in [0] + rotation_lengths[:-1] - ] + ) -> SavableDataset[T_encoded_sample]: + """Builds the (blend) → (repeat/shuffle) → (preencode/encode) pipeline.""" if blend_mode == DatasetBlendMode.DATASET_WEIGHT: assert repeat, ( @@ -874,14 +918,134 @@ def build_train_datasets( size=shuffle_buffer_size, worker_config=worker_config, ) - dataset = self.build_encode_sample(dataset, worker_config=worker_config) + return self.build_encode_sample(dataset, worker_config=worker_config) + + def _compute_rotation_offsets( + self, datasets: List[LoadedDataset], worker_config: WorkerConfig + ) -> List[int]: + global_workers = max(1, worker_config.num_workers) * worker_config.world_size + rotation_lengths = [len(d.dataset) for d in datasets] + for i in range(1, 
len(rotation_lengths)): + rotation_lengths[i] += rotation_lengths[i - 1] + return [rotation_length % global_workers for rotation_length in [0] + rotation_lengths[:-1]] + + def _build_train_blend_shuffle_encode_groups( + self, + datasets: List[LoadedDataset], + packing_buffer_size: Optional[int | dict[str | None, int | None]], + blend_mode: DatasetBlendMode, + repeat: bool, + shuffle_buffer_size: Optional[int | dict[str | None, int | None]], + worker_config: WorkerConfig, + ) -> SavableDataset[T_encoded_sample]: + """Builds the train pipeline with optional per-dataset-group isolation. + + Splits ``datasets`` by :attr:`~megatron.energon.metadataset.loader_interface.LoadedDataset.group`. + For each group, runs blend → (optional shuffle) → encode, then applies packing/postencode for + that group's ``packing_buffer_size`` / ``shuffle_buffer_size`` entries. When multiple groups + exist, blends the resulting streams with weights from :meth:`_group_weight`. + + Pipeline per group: + ``blend → shuffle → encode → select_samples_to_pack → postencode → pack_selected_samples``. + Multiple groups: ``(... per group ...) → blend``. 
+ """ + rotation_offsets = self._compute_rotation_offsets(datasets, worker_config) + + dataset_groups: dict[Optional[str], tuple[list[LoadedDataset], list[int]]] = {} + for ld, ro in zip(datasets, rotation_offsets): + if ld.group in dataset_groups: + dataset_groups[ld.group][0].append(ld) + dataset_groups[ld.group][1].append(ro) + else: + dataset_groups[ld.group] = ([ld], [ro]) + + streams: List[tuple[SavableDataset[Any], float]] = [] + for group_key, (group_ds, rotation_offsets) in dataset_groups.items(): + if isinstance(shuffle_buffer_size, dict): + group_shuffle_buffer_size = shuffle_buffer_size[group_key] + else: + group_shuffle_buffer_size = shuffle_buffer_size + + dataset = self._build_train_blend_shuffle_encode_branch( + datasets=group_ds, + worker_rotation_offsets=rotation_offsets, + blend_mode=blend_mode, + repeat=repeat, + shuffle_buffer_size=group_shuffle_buffer_size, + worker_config=worker_config, + ) + # Post-encode is included + dataset = self._build_packing_postencode( + dataset, + group=group_key, + packing_buffer_size=packing_buffer_size, + worker_config=worker_config, + ) + streams.append( + ( + dataset, + self._group_weight(group_ds, blend_mode, repeat=repeat), + ) + ) + + if len(streams) > 1: + return BlendDataset(*streams, worker_config=worker_config) + else: + return streams[0][0] + + def build_train_datasets( + self, + *, + datasets: List[LoadedDataset], + worker_config: WorkerConfig, + batch_size: Optional[int], + batch_drop_last: bool = False, + packing_buffer_size: Optional[int | dict[str | None, int | None]] = None, + virtual_epoch_length: int = 0, + shuffle_buffer_size: Optional[int | dict[str | None, int | None]] = None, + blend_mode: DatasetBlendMode = DatasetBlendMode.NONE, + repeat: bool = True, + ) -> SavableDataset[T_batch]: + """Combines train datasets into one batched dataset pipeline. + + Args: + datasets: Loaded leaf datasets (each carries a ``group`` key when using Metadataset V2). 
+ worker_config: Worker configuration for wrapped datasets. + batch_size: Batch dimension; ``None`` skips batching. + batch_drop_last: If true, drop the last batch when smaller than ``batch_size``. + packing_buffer_size: Packing buffer size, or a dict mapping dataset group keys + (including ``None`` for the default group) to sizes or ``None`` to disable packing + per group. + virtual_epoch_length: If positive, wraps with epochization at this length. + shuffle_buffer_size: Shuffle buffer before encoding, or per-group dict like + ``packing_buffer_size``. + blend_mode: How leaf weights map to the inner :class:`~megatron.energon.BlendDataset`. + repeat: Whether inner datasets loop indefinitely. + + Returns: + The full train :class:`~megatron.energon.flavors.SavableDataset` pipeline. + """ + + # Check if there's a CrudeWebdataset but no cookers + for dataset in datasets: + if isinstance(dataset.dataset, CrudeWebdataset): + assert self.cookers, "CrudeWebdataset found, but no cookers registered." 
+ + dataset = self._build_train_blend_shuffle_encode_groups( + datasets=datasets, + packing_buffer_size=packing_buffer_size, + blend_mode=blend_mode, + repeat=repeat, + shuffle_buffer_size=shuffle_buffer_size, + worker_config=worker_config, + ) dataset = self.build_batch( dataset, batch_size=batch_size, batch_drop_last=batch_drop_last, - packing_buffer_size=packing_buffer_size, worker_config=worker_config, ) + if virtual_epoch_length > 0: dataset = EpochizeDataset( dataset, @@ -893,31 +1057,13 @@ def build_train_datasets( return dataset - def build_val_datasets( + def _build_val_concat_encode_branch( self, - *, datasets: List[LoadedDataset], + worker_rotation_offsets: List[int], worker_config: WorkerConfig, - batch_size: int, - batch_drop_last: bool = False, - packing_buffer_size: Optional[int] = None, - limit: Optional[int] = None, - ) -> SavableDataset[T_batch]: - """Combines val datasets to a single dataset.""" - - # Check if there's a CrudeWebdataset but no cookers - for dataset in datasets: - if isinstance(dataset, CrudeWebdataset): - assert self.cookers, "CrudeWebdataset found, but no cookers registered." 
- - global_workers = max(1, worker_config.num_workers) * worker_config.world_size - rotation_lengths = [len(dataset.dataset) for dataset in datasets] - for i in range(1, len(rotation_lengths)): - rotation_lengths[i] += rotation_lengths[i - 1] - worker_rotation_offsets = [ - rotation_length % global_workers for rotation_length in [0] + rotation_lengths[:-1] - ] - + ) -> SavableDataset[T_encoded_sample]: + """Builds the (concat) → (preencode/encode) pipeline.""" if len(datasets) > 1: dataset = ConcatDataset( *[ @@ -930,12 +1076,93 @@ def build_val_datasets( dataset = self._load_dataset(datasets[0], worker_rotation_offsets[0], worker_config) else: raise ValueError("No datasets given.") - dataset = self.build_encode_sample(dataset, worker_config=worker_config) + return self.build_encode_sample(dataset, worker_config=worker_config) + + def _build_val_concat_encode_groups( + self, + datasets: List[LoadedDataset], + packing_buffer_size: Optional[int | dict[str | None, int | None]], + worker_config: WorkerConfig, + ) -> SavableDataset[T_encoded_sample]: + """Builds the validation pipeline with optional per-dataset-group isolation. + + Like :meth:`_build_train_blend_shuffle_encode_groups`, but concatenates leaves instead of + blending, and omits shuffle/repeat. Splits ``datasets`` by ``LoadedDataset.group``, applies + packing per group's ``packing_buffer_size`` entry, then concatenates group streams when needed. + + Pipeline per group: + ``concat loaded leaves → encode → select_samples_to_pack → postencode → pack_selected_samples``. + Multiple groups: ``(... per group ...) → concat``. 
+ """ + rotation_offsets = self._compute_rotation_offsets(datasets, worker_config) + + dataset_groups: dict[Optional[str], tuple[list[LoadedDataset], list[int]]] = {} + for ld, ro in zip(datasets, rotation_offsets): + if ld.group in dataset_groups: + dataset_groups[ld.group][0].append(ld) + dataset_groups[ld.group][1].append(ro) + else: + dataset_groups[ld.group] = ([ld], [ro]) + + streams: List[SavableDataset[Any]] = [] + for group_key, (group_ds, rotation_offsets) in dataset_groups.items(): + branch = self._build_val_concat_encode_branch( + datasets=group_ds, + worker_rotation_offsets=rotation_offsets, + worker_config=worker_config, + ) + # Post-encode is included + branch = self._build_packing_postencode( + branch, + group=group_key, + packing_buffer_size=packing_buffer_size, + worker_config=worker_config, + ) + streams.append(branch) + if len(streams) > 1: + return ConcatDataset(*streams, worker_config=worker_config) + else: + return streams[0] + + def build_val_datasets( + self, + *, + datasets: List[LoadedDataset], + worker_config: WorkerConfig, + batch_size: int, + batch_drop_last: bool = False, + packing_buffer_size: Optional[int | dict[str | None, int | None]] = None, + limit: Optional[int] = None, + ) -> SavableDataset[T_batch]: + """Combines validation datasets into one batched dataset pipeline. + + Args: + datasets: Loaded leaf datasets (each may carry a ``group`` key when using Metadataset V2). + worker_config: Worker configuration for wrapped datasets. + batch_size: Batch dimension. + batch_drop_last: If true, drop the last batch when smaller than ``batch_size``. + packing_buffer_size: Packing buffer size, or per-group dict (keys include ``None`` for the + default group). + limit: If set and positive, caps the number of batches via :class:`~megatron.energon.LimitDataset`. + + Returns: + The full validation :class:`~megatron.energon.flavors.SavableDataset` pipeline. 
+ """ + + # Check if there's a CrudeWebdataset but no cookers + for dataset in datasets: + if isinstance(dataset, CrudeWebdataset): + assert self.cookers, "CrudeWebdataset found, but no cookers registered." + + dataset = self._build_val_concat_encode_groups( + datasets=datasets, + packing_buffer_size=packing_buffer_size, + worker_config=worker_config, + ) dataset = self.build_batch( dataset, batch_size=batch_size, batch_drop_last=batch_drop_last, - packing_buffer_size=packing_buffer_size, worker_config=worker_config, ) if limit is not None and limit > 0: diff --git a/src/megatron/energon/task_encoder/loader.py b/src/megatron/energon/task_encoder/loader.py index 5c680bb2..4e3d0bc7 100644 --- a/src/megatron/energon/task_encoder/loader.py +++ b/src/megatron/energon/task_encoder/loader.py @@ -111,8 +111,8 @@ def get_train_dataset( worker_config: WorkerConfig, batch_size: Optional[int], batch_drop_last: bool = False, - packing_buffer_size: Optional[int] = None, - shuffle_buffer_size: Optional[int], + packing_buffer_size: Optional[int | dict[str | None, int | None]] = None, + shuffle_buffer_size: Optional[int | dict[str | None, int | None]], max_samples_per_sequence: Optional[int], virtual_epoch_length: int = 0, shuffle_over_epochs_multiplier: Optional[int] = 1, @@ -137,7 +137,12 @@ def get_train_dataset( worker_config: Worker configuration to use. batch_size: Size of a batch. If None, do not batch batch_drop_last: If true, drop the last batch if it is smaller than `batch_size`. - shuffle_buffer_size: Size of the sample shuffle buffer (before task encoding). + packing_buffer_size: Size of the packing buffer. If a dict, keys are dataset group names from + Metadataset V2 (YAML ``group``, merged along the path; ``None`` is the default group). + Values are buffer sizes, or ``None`` to disable packing for that group. + shuffle_buffer_size: Sample shuffle buffer size before task encoding. 
If a dict, keys are the + same dataset group names as for ``packing_buffer_size``; values are buffer sizes or ``None`` + to disable shuffling for that group. max_samples_per_sequence: If set, limit the number of samples per sample-sequence to this. virtual_epoch_length: If set, the dataset will be epochized to this length (=iterating will be suspended and the for-loop returns, next for-loop continues iterating). @@ -186,7 +191,7 @@ def get_val_dataset( worker_config: WorkerConfig, batch_size: int, batch_drop_last: bool = False, - packing_buffer_size: Optional[int] = None, + packing_buffer_size: Optional[int | dict[str | None, int | None]] = None, limit: Optional[int] = None, task_encoder: TaskEncoder[Any, Any, Any, T] = DefaultTaskEncoder(), **kwargs, @@ -208,6 +213,9 @@ def get_val_dataset( worker_config: Worker configuration to use. batch_size: Size of a batch batch_drop_last: If true, drop the last batch if it is smaller than `batch_size`. + packing_buffer_size: Size of the packing buffer. If a dict, keys are dataset group names from + Metadataset V2 (same as ``get_train_dataset``); values are buffer sizes or ``None`` to + disable packing per group. limit: If set, limit the number of batches loaded from the dataset to this. task_encoder: Task encoder to use. **kwargs: Additional arguments to the dataset constructor. @@ -241,7 +249,7 @@ def get_val_datasets( worker_config: WorkerConfig, batch_size: int, batch_drop_last: bool = False, - packing_buffer_size: Optional[int] = None, + packing_buffer_size: Optional[int | dict[str | None, int | None]] = None, limit: Optional[int] = None, task_encoder: TaskEncoder[Any, Any, Any, T] = DefaultTaskEncoder(), **kwargs, @@ -263,6 +271,9 @@ def get_val_datasets( worker_config: Worker configuration to use. batch_size: Size of a batch batch_drop_last: If true, drop the last batch if it is smaller than `batch_size`. + packing_buffer_size: Size of the packing buffer. 
If a dict, keys are dataset group names from + Metadataset V2 (same as ``get_train_dataset``); values are buffer sizes or ``None`` to + disable packing per group. limit: If set, limit the number of batches loaded from the dataset to this. task_encoder: Task encoder to use. **kwargs: Additional arguments to the dataset constructor. diff --git a/src/megatron/energon/wrappers/__init__.py b/src/megatron/energon/wrappers/__init__.py index 808ca50e..a0fe1273 100644 --- a/src/megatron/energon/wrappers/__init__.py +++ b/src/megatron/energon/wrappers/__init__.py @@ -18,7 +18,11 @@ generic_concat, homogeneous_concat_mix, ) -from megatron.energon.wrappers.packing_dataset import PackingDataset +from megatron.energon.wrappers.packing_dataset import ( + PackedSamplesOutput, + PackingDataset, + PartialSample, +) from megatron.energon.wrappers.repeat_dataset import RepeatDataset from megatron.energon.wrappers.shuffle_buffer_dataset import ShuffleBufferDataset from megatron.energon.wrappers.skip import SkipSample @@ -40,6 +44,8 @@ "ShuffleBufferDataset", "SkipSample", "PackingDataset", + "PackedSamplesOutput", + "PartialSample", "concat_pad", "generic_concat", "homogeneous_concat_mix", diff --git a/src/megatron/energon/wrappers/base.py b/src/megatron/energon/wrappers/base.py index 2f4dba89..9266d44d 100644 --- a/src/megatron/energon/wrappers/base.py +++ b/src/megatron/energon/wrappers/base.py @@ -182,7 +182,7 @@ def restore_state(self, state: Optional[int]) -> None: self.current_idx = state -def get_sample_restore_key(sample: Any) -> Optional[Union[str, int]]: +def get_sample_restore_key(sample: Any) -> Optional[Union[str, int, tuple]]: """Gets the restore key from an arbitrary sample.""" if isinstance(sample, Sample) or hasattr(sample, "__restore_key__"): return sample.__restore_key__ diff --git a/src/megatron/energon/wrappers/packing_dataset.py b/src/megatron/energon/wrappers/packing_dataset.py index ee10e6c2..36e479e0 100644 --- a/src/megatron/energon/wrappers/packing_dataset.py 
+++ b/src/megatron/energon/wrappers/packing_dataset.py @@ -3,6 +3,7 @@ import contextlib import inspect +from dataclasses import dataclass, field from typing import ( Any, Callable, @@ -12,10 +13,12 @@ Iterator, List, Optional, + Sequence, TypeVar, Union, ) +from megatron.energon.edataclass import edataclass from megatron.energon.errors import ErrorContext, handle_restore_errors from megatron.energon.flavors.base_dataset import ( SavableDataset, @@ -23,35 +26,93 @@ set_sample_restore_key, ) from megatron.energon.worker import WorkerConfig -from megatron.energon.wrappers.base import BaseWrapperDataset, SampleIndex, get_sample_restore_key +from megatron.energon.wrappers.base import ( + BaseWrapperDataset, + SampleIndex, + get_sample_restore_key, +) from megatron.energon.wrappers.buffer import SavableSampleBuffer T_sample = TypeVar("T_sample") T_encoded_sample = TypeVar("T_encoded_sample") T_batch_sample = TypeVar("T_batch_sample") +T_slice = TypeVar("T_slice") + + +@edataclass +class PartialSample(Generic[T_sample, T_slice]): + """A view onto part of a sample selected during packing.""" + + sample: T_sample + slice: T_slice + __restore_key__: tuple | None = field(init=False, repr=False, default=None) + + def __post_init__(self) -> None: + inner_restore_key = get_sample_restore_key(self.sample) + if inner_restore_key is not None: + set_sample_restore_key(self, inner_restore_key, self.slice, src=self) + + +class SavablePartialSampleBuffer( + SavableSampleBuffer[T_sample | PartialSample[T_sample, T_slice]], + Generic[T_sample, T_slice], +): + """Sample buffer that can save and restore packing-only partial sample wrappers.""" + + def restore_sample(self, restore_key: Any) -> T_sample | PartialSample[T_sample, T_slice]: + if ( + isinstance(restore_key, (tuple, list)) + and len(restore_key) == 3 + and restore_key[0] == PartialSample.__name__ + ): + _, sample_restore_key, sample_slice = restore_key + return PartialSample( + sample=self.restore_sample(sample_restore_key), + 
slice=sample_slice, + ) + return self.dataset.restore_sample(restore_key) + + +@dataclass(slots=True, frozen=True) +class PackedSamplesOutput(Generic[T_sample]): + """Return type for :meth:`~megatron.energon.TaskEncoder.select_samples_to_pack` when including + samples to push back onto the packing reading buffer.""" + + #: The packs of samples to be packed together. + # One entry per pack, each containing the samples to be packed together. + packs: list[list[T_sample]] + + #: The samples to push back onto the packing reading buffer for the next fill. + pushback: Sequence[T_sample] = field(default_factory=tuple) class PackingDataset( BaseWrapperDataset[T_sample, T_batch_sample], - Generic[T_sample, T_encoded_sample, T_batch_sample], + Generic[T_sample, T_encoded_sample, T_batch_sample, T_slice], ): """This dataset wrapper transforms samples of a dataset into chunks/packs of samples, which are then combined into a batch.""" buffer_size: int - pre_packer: Callable[[List[T_sample]], List[List[T_sample]]] - sample_encoder: Optional[Callable[[T_sample], T_encoded_sample]] + pre_packer: Callable[ + [list[T_sample | PartialSample[T_sample, Any]]], + list[list[T_sample | PartialSample[T_sample, Any]]] + | PackedSamplesOutput[T_sample | PartialSample[T_sample, Any]], + ] + sample_encoder: Optional[Callable[[T_sample | PartialSample[T_sample, Any]], T_encoded_sample]] sample_encoder_stateless: bool - final_packer: Callable[[List[T_encoded_sample]], T_batch_sample] + final_packer: Callable[ + [List[T_encoded_sample | T_sample | PartialSample[T_sample, Any]]], T_batch_sample + ] final_packer_stateless: bool packer_config: Optional[Union[Dict[str, Any], Callable[[], Dict[str, Any]]]] #: The buffer for collecting the samples that shall be packed. - _reading_buffer: SavableSampleBuffer + _reading_buffer: SavablePartialSampleBuffer[T_sample, T_slice] #: Contains the pre-selected samples to be packed. #: The full buffer will be passed to the pre_packer. 
- _pre_packing_buffer: SavableSampleBuffer + _pre_packing_buffer: SavablePartialSampleBuffer[T_sample, T_slice] #: Lengths of the selected groups of samples to be packed together. #: The samples are stored sequentially in the pre_packing_buffer because @@ -86,11 +147,19 @@ def __init__( self, dataset: SavableDataset[T_sample], buffer_size: int, - pre_packer: Callable[[List[T_sample]], List[List[T_sample]]], - final_packer: Callable[[List[T_encoded_sample]], T_batch_sample], + pre_packer: Callable[ + [list[T_sample | PartialSample[T_sample, T_slice]]], + list[list[T_sample | PartialSample[T_sample, T_slice]]] + | PackedSamplesOutput[T_sample | PartialSample[T_sample, T_slice]], + ], + final_packer: Callable[ + [List[T_encoded_sample | T_sample | PartialSample[T_sample, T_slice]]], T_batch_sample + ], *, final_packer_stateless: bool = False, - sample_encoder: Optional[Callable[[T_sample], T_encoded_sample]] = None, + sample_encoder: Optional[ + Callable[[T_sample | PartialSample[T_sample, T_slice]], T_encoded_sample] + ] = None, sample_encoder_stateless: bool = False, packer_config: Optional[Union[Dict[str, Any], Callable[[], Dict[str, Any]]]] = None, pre_packer_failure_tolerance: int = 100, @@ -108,7 +177,8 @@ def __init__( dataset: The input dataset to wrap buffer_size: The desired size of the input buffer for pre packing. Last buffer of a dataset may be smaller. pre_packer: Function which selects samples from the buffer to be packed together. - May raise :exc:`megatron.energon.SkipSample` to skip a buffer. + May return :class:`PackedSamplesOutput` to include a ``pushback`` sequence + reapplied to the reading buffer. May raise :exc:`megatron.energon.SkipSample` to skip a buffer. final_packer: Function which combines the selected samples into a single sample. final_packer_stateless: If True, the final_packer is stateless, thus samples can be stored/restored. 
@@ -160,9 +230,13 @@ def __init__( self.reset_state_own() def reset_state_own(self) -> None: - self._reading_buffer = SavableSampleBuffer(self.dataset, worker_config=self.worker_config) - self._pre_packing_buffer = SavableSampleBuffer( - self.dataset, worker_config=self.worker_config + self._reading_buffer = SavablePartialSampleBuffer( + self.dataset, + worker_config=self.worker_config, + ) + self._pre_packing_buffer = SavablePartialSampleBuffer( + self.dataset, + worker_config=self.worker_config, ) self._pre_packing_lengths = [] self._pre_packing_sample_index = SampleIndex(self.worker_config, src=self) @@ -218,7 +292,9 @@ def __iter__(self) -> Iterator[T_batch_sample]: is_initial_pack = True - def encode_pack_samples(pack: List[T_sample]) -> List[T_encoded_sample]: + def encode_pack_samples( + pack: List[T_sample | PartialSample[T_sample, Any]], + ) -> List[T_encoded_sample | T_sample | PartialSample[T_sample, Any]]: """Encode the samples in the pack using the sample encoder.""" # Apply the sample encoder to the pack @@ -226,10 +302,17 @@ def encode_pack_samples(pack: List[T_sample]) -> List[T_encoded_sample]: return pack encoded_pack = [] for sample in pack: + input_restore_key = get_sample_restore_key(sample) with self._sample_encoder_failure_handler.handle_errors(sample): with self._sample_encoder_sample_index.ctx() as encode_idx: encoded_sample = self.sample_encoder(sample) assert not isinstance(encoded_sample, Generator), "Generator not supported" + if isinstance(sample, PartialSample) and input_restore_key is not None: + encoded_sample = set_sample_restore_key( + encoded_sample, + *input_restore_key[1:], + src=sample, + ) self._sample_encoder_failure_handler.reset() encoded_pack.append( add_sample_restore_key( @@ -245,26 +328,35 @@ def next_pre_pack(): together.""" assert self._pre_packing_buffer.len_worker() == 0 - if self._reading_buffer.len_worker() > 0: - # Take all samples from the reading buffer and pre_pack them - samples = 
self._reading_buffer.buffer.copy() - # Clear buffer and pre_packing_lengths - self._reading_buffer.clear() - pre_packing_lengths.clear() - # Now pre pack the samples - pre_packs = [] - with self._pre_pack_failure_handler.handle_errors(samples): - with self._pre_packing_sample_index.ctx(): - pre_packs = self.pre_packer(samples) - - # Put the pre-packed samples into the pre_packing_buffer - # They will be flattened here to avoid nested buffers - # But the lengths of the groups are stored in pre_packing_lengths - # so that the groups can be separated later - for pre_pack in pre_packs: - if len(pre_pack) > 0: - self._pre_packing_buffer.extend(pre_pack) - pre_packing_lengths.append(len(pre_pack)) + if self._reading_buffer.len_worker() == 0: + return + + # Take all samples from the reading buffer and pre_pack them + samples = self._reading_buffer.buffer.copy() + # Clear buffer and pre_packing_lengths + self._reading_buffer.clear() + pre_packing_lengths.clear() + # Now pre pack the samples + pre_packs_result: list[list[T_sample]] | PackedSamplesOutput[T_sample] | None = None + with self._pre_pack_failure_handler.handle_errors(samples): + with self._pre_packing_sample_index.ctx(): + pre_packs_result = self.pre_packer(samples) + if pre_packs_result is None: + # The error handler may actually catch the error + return + # Put the pre-packed samples into the pre_packing_buffer + # They will be flattened here to avoid nested buffers + # But the lengths of the groups are stored in pre_packing_lengths + # so that the groups can be separated later + if isinstance(pre_packs_result, PackedSamplesOutput): + pre_packs = pre_packs_result.packs + self._reading_buffer.extend(list(pre_packs_result.pushback)) + else: + pre_packs = pre_packs_result + for pre_pack in pre_packs: + if len(pre_pack) > 0: + self._pre_packing_buffer.extend(pre_pack) + pre_packing_lengths.append(len(pre_pack)) def next_final_pack() -> Generator[T_batch_sample, None, None]: """Yield the next packs from the buffer. 
The final packer is called on the fly.""" @@ -335,18 +427,26 @@ def next_final_pack() -> Generator[T_batch_sample, None, None]: yield from next_final_pack() - # Yield the remaining packs, flushing the collecting buffer - while len(pre_packing_lengths) > 0: - yield from next_final_pack() + # Yield the remaining packs, flushing carryover partials until the source is exhausted. + flush_round = 0 + while len(pre_packing_lengths) > 0 or self._reading_buffer.len_worker() > 0: + if len(pre_packing_lengths) > 0: + flush_round = 0 + yield from next_final_pack() + continue + # If there are still samples in the partial reading buffer, pre-pack them and yield the + # resulting (partial) packs - # If there are still samples in the partial reading buffer, pre-pack them and yield the - # resulting (partial) packs - if self._reading_buffer.len_worker() > 0: next_pre_pack() - - # Yield the remaining packs, flushing the collecting buffer - while len(pre_packing_lengths) > 0: - yield from next_final_pack() + if len(pre_packing_lengths) == 0 and self._reading_buffer.len_worker() > 0: + flush_round += 1 + if ( + self.pre_packer_failure_tolerance > 0 + and flush_round > self.pre_packer_failure_tolerance + ): + raise RuntimeError( + f"Pre packer {self.pre_packer} did not yield any packs after {flush_round} flush rounds. Likely your code or dataset are broken." + ) def can_restore_sample(self) -> bool: # Cannot really verify if the returned elements contain a __restore_key__. @@ -367,28 +467,32 @@ def restore_sample(self, restore_key: Any) -> T_sample: # We need to store multiple indices to restore a batch. 
self.assert_can_restore() if inspect.isgeneratorfunction(self.final_packer): - id, pack_idx, pack_sub_idx, *pack_restore_keys = restore_key id, pack_idx, pack_sub_idx, *pack_restore_keys = restore_key assert id == type(self).__name__ else: - id, pack_idx, *pack_restore_keys = restore_key id, pack_idx, *pack_restore_keys = restore_key assert id == type(self).__name__ pack = [] for inner_idx in pack_restore_keys: if self.sample_encoder is not None: - id, sample_idx, *inner_idx = inner_idx - assert id == type(self).__name__ id, sample_idx, *inner_idx = inner_idx assert id == type(self).__name__ assert isinstance(sample_idx, int) - sample = self.dataset.restore_sample(inner_idx) + sample = self._reading_buffer.restore_sample(inner_idx) if self.sample_encoder is not None: with handle_restore_errors(self.worker_config.restore_error_handler, sample): + input_sample = sample + input_restore_key = get_sample_restore_key(input_sample) with self._sample_encoder_sample_index.ctx(sample_idx): sample = self.sample_encoder(sample) assert not isinstance(sample, Generator), "Generator not supported" + if isinstance(input_sample, PartialSample) and input_restore_key is not None: + sample = set_sample_restore_key( + sample, + *input_restore_key[1:], + src=input_sample, + ) sample = add_sample_restore_key(sample, sample_idx, src=self) pack.append(sample) diff --git a/tests/test_dataset.py b/tests/test_dataset.py index 01aa74ee..e0a99f28 100644 --- a/tests/test_dataset.py +++ b/tests/test_dataset.py @@ -34,6 +34,8 @@ DefaultTaskEncoder, MapDataset, MixBatchDataset, + PackedSamplesOutput, + PartialSample, Sample, SavableDataLoader, TaskEncoder, @@ -1561,6 +1563,351 @@ def pack_selected_samples( assert restored_sample_1.__key__ == samples[1].__key__ assert restored_sample_1.__restore_key__ == samples[1].__restore_key__ + def test_packing_pushback(self): + """Pushback must prepend deferred samples ahead of new iterator pulls. 
+ + With ``buffer_size`` 5, the first full buffer follows the dataset iterator order (not + necessarily sorted ``__key__``). We pack only the first sample and push the rest back; the + next pre-pack must see the four deferred samples first, then one newly drawn sample — i.e. + the second buffer's keys must start with the first buffer's keys at indices ``[1:5]``. + """ + torch.manual_seed(42) + + buf = 5 + + class TestTaskEncoder(DefaultTaskEncoder): + def __init__(self): + super().__init__(raw_batch_type=CaptioningBatch) + self.buffer_keys_per_call: list[list[str]] = [] + self.first_fill_keys: list[str] | None = None + + @stateless + def encode_sample(self, sample: CaptioningSample) -> EncodedCaptioningSample: + return EncodedCaptioningSample.derive_from( + sample, + image=sample.image, + caption=torch.frombuffer(sample.caption.encode(), dtype=torch.uint8), + ) + + def select_samples_to_pack( + self, samples: list[EncodedCaptioningSample] + ) -> PackedSamplesOutput[EncodedCaptioningSample]: + keys = [s.__key__ for s in samples] + self.buffer_keys_per_call.append(keys.copy()) + call = len(self.buffer_keys_per_call) + if call == 1: + assert len(samples) == buf + self.first_fill_keys = keys.copy() + return PackedSamplesOutput( + packs=[samples[:1]], + pushback=tuple(samples[1:]), + ) + if call == 2: + assert len(samples) == buf + assert self.first_fill_keys is not None + assert keys[: buf - 1] == self.first_fill_keys[1:buf], ( + "Pushback did not appear before new samples: " + f"second_fill[:4]={keys[: buf - 1]!r} " + f"expected first_fill[1:5]={self.first_fill_keys[1:buf]!r}" + ) + return PackedSamplesOutput( + packs=[[s] for s in samples], + pushback=(), + ) + return PackedSamplesOutput( + packs=[[s] for s in samples], + pushback=(), + ) + + @stateless + def pack_selected_samples( + self, samples: list[EncodedCaptioningSample] + ) -> EncodedCaptioningSample: + return EncodedCaptioningSample( + __key__=",".join(sample.__key__ for sample in samples), + 
__restore_key__=(), + image=torch.stack([sample.image for sample in samples]), + caption=torch.cat([sample.caption for sample in samples]), + ) + + task_encoder = TestTaskEncoder() + loader = get_loader( + get_train_dataset( + self.dataset_path, + batch_size=2, + packing_buffer_size=buf, + worker_config=no_worker_config, + virtual_epoch_length=4, + shuffle_buffer_size=None, + max_samples_per_sequence=None, + task_encoder=task_encoder, + ) + ) + + out = list(loader) + assert len(out) == 4 + assert task_encoder.first_fill_keys is not None + assert ( + task_encoder.buffer_keys_per_call[1][: buf - 1] == task_encoder.first_fill_keys[1:buf] + ) + restored = loader.restore_sample(out[0].__restore_key__) + assert restored.__key__ == out[0].__key__ + + def test_packing_partial_sample_carryover(self): + """Partial samples carry unconsumed token slices across packing rounds.""" + + torch.manual_seed(42) + max_tokens = 11 + buffer_size = 4 + + class TestTaskEncoder(DefaultTaskEncoder): + def __init__(self): + super().__init__(raw_batch_type=CaptioningBatch) + + @staticmethod + def _window( + sample: EncodedCaptioningSample + | PartialSample[EncodedCaptioningSample, tuple[int, int]], + ) -> tuple[EncodedCaptioningSample, int, int]: + if isinstance(sample, PartialSample): + token_start, token_stop = sample.slice + return sample.sample, token_start, token_stop + return sample, 0, len(sample.caption) + + @stateless + def encode_sample(self, sample: CaptioningSample) -> EncodedCaptioningSample: + return EncodedCaptioningSample.derive_from( + sample, + image=sample.image, + caption=torch.frombuffer(bytearray(sample.caption.encode()), dtype=torch.uint8), + ) + + def select_samples_to_pack( + self, + samples: List[ + EncodedCaptioningSample + | PartialSample[EncodedCaptioningSample, tuple[int, int]] + ], + ) -> PackedSamplesOutput[ + EncodedCaptioningSample + | PartialSample[EncodedCaptioningSample, tuple[int, int]] + ]: + pack: list[ + EncodedCaptioningSample + | 
PartialSample[EncodedCaptioningSample, tuple[int, int]] + ] = [] + pushback: list[ + EncodedCaptioningSample + | PartialSample[EncodedCaptioningSample, tuple[int, int]] + ] = [] + remaining = max_tokens + + for sample in samples: + base_sample, token_start, token_stop = self._window(sample) + token_count = token_stop - token_start + if remaining == 0: + pushback.append(sample) + elif token_count <= remaining: + pack.append(sample) + remaining -= token_count + else: + pack.append( + PartialSample( + sample=base_sample, + slice=(token_start, token_start + remaining), + ) + ) + pushback.append( + PartialSample( + sample=base_sample, + slice=(token_start + remaining, token_stop), + ) + ) + remaining = 0 + + return PackedSamplesOutput( + packs=[pack] if pack else [], + pushback=tuple(pushback), + ) + + @stateless + def postencode_sample( + self, + sample: EncodedCaptioningSample + | PartialSample[EncodedCaptioningSample, tuple[int, int]], + ) -> EncodedCaptioningSample: + base_sample, token_start, token_stop = self._window(sample) + return EncodedCaptioningSample.derive_from( + base_sample, + __key__=f"{base_sample.__key__}:{token_start}:{token_stop}", + image=base_sample.image, + caption=base_sample.caption[token_start:token_stop], + ) + + @stateless + def pack_selected_samples( + self, samples: List[EncodedCaptioningSample] + ) -> EncodedCaptioningSample: + return EncodedCaptioningSample( + __key__="|".join(sample.__key__ for sample in samples), + __restore_key__=(), + image=torch.stack([sample.image for sample in samples]), + caption=torch.cat([sample.caption for sample in samples]), + ) + + class FinalPackerTaskEncoder(DefaultTaskEncoder): + def __init__(self): + super().__init__(raw_batch_type=CaptioningBatch) + + @staticmethod + def _window( + sample: EncodedCaptioningSample + | PartialSample[EncodedCaptioningSample, tuple[int, int]], + ) -> tuple[EncodedCaptioningSample, int, int]: + return TestTaskEncoder._window(sample) + + @stateless + def encode_sample(self, 
sample: CaptioningSample) -> EncodedCaptioningSample: + return EncodedCaptioningSample.derive_from( + sample, + image=sample.image, + caption=torch.frombuffer(bytearray(sample.caption.encode()), dtype=torch.uint8), + ) + + def select_samples_to_pack( + self, + samples: List[ + EncodedCaptioningSample + | PartialSample[EncodedCaptioningSample, tuple[int, int]] + ], + ) -> PackedSamplesOutput[ + EncodedCaptioningSample + | PartialSample[EncodedCaptioningSample, tuple[int, int]] + ]: + return TestTaskEncoder.select_samples_to_pack(self, samples) + + @stateless + def pack_selected_samples( + self, + samples: List[ + EncodedCaptioningSample + | PartialSample[EncodedCaptioningSample, tuple[int, int]] + ], + ) -> EncodedCaptioningSample: + parts = [] + images = [] + captions = [] + for sample in samples: + base_sample, token_start, token_stop = self._window(sample) + parts.append(f"{base_sample.__key__}:{token_start}:{token_stop}") + images.append(base_sample.image) + captions.append(base_sample.caption[token_start:token_stop]) + return EncodedCaptioningSample( + __key__="|".join(parts), + __restore_key__=(), + image=torch.stack(images), + caption=torch.cat(captions), + ) + + def assert_segments(samples: list[EncodedCaptioningSample]) -> None: + by_key: dict[str, list[tuple[int, int, bytes]]] = defaultdict(list) + for sample in samples: + caption_offset = 0 + for part in sample.__key__.split("|"): + key, raw_token_start, raw_token_stop = part.split(":") + token_start = int(raw_token_start) + token_stop = int(raw_token_stop) + token_count = token_stop - token_start + segment = bytes( + sample.caption[caption_offset : caption_offset + token_count].tolist() + ) + expected = self.samples[int(key)]["caption"].encode()[token_start:token_stop] + assert segment == expected + by_key[key].append((token_start, token_stop, segment)) + caption_offset += token_count + assert caption_offset == len(sample.caption) + assert caption_offset == max_tokens + + complete_keys = 0 + for key, 
segments in by_key.items(): + expected = self.samples[int(key)]["caption"].encode() + if sum(token_stop - token_start for token_start, token_stop, _ in segments) != len( + expected + ): + continue + actual = bytearray(len(expected)) + covered = [False] * len(expected) + for token_start, token_stop, segment in segments: + actual[token_start:token_stop] = segment + covered[token_start:token_stop] = [True] * (token_stop - token_start) + assert all(covered) + assert bytes(actual) == expected + complete_keys += 1 + assert complete_keys > 0 + + def build_loader(): + return get_savable_loader( + get_train_dataset( + self.dataset_path, + batch_size=None, + packing_buffer_size=buffer_size, + worker_config=no_worker_config, + virtual_epoch_length=8, + shuffle_buffer_size=None, + max_samples_per_sequence=None, + task_encoder=TestTaskEncoder(), + ), + checkpoint_every_min_n_samples=1, + checkpoint_every_sec=0, + ) + + loader = build_loader() + loader_iter = iter(loader) + first_sample = next(loader_iter) + rank_state = loader.save_state_rank() + cmp_samples = [next(loader_iter) for _ in range(5)] + + restored_loader = build_loader() + restored_loader.restore_state_rank(rank_state) + restored_samples = [sample for _, sample in zip(range(5), restored_loader)] + + assert loader.can_restore_sample() + restored_first_sample = loader.restore_sample(first_sample.__restore_key__) + assert restored_first_sample.__key__ == first_sample.__key__ + assert torch.equal(restored_first_sample.caption, first_sample.caption) + assert [sample.__key__ for sample in restored_samples] == [ + sample.__key__ for sample in cmp_samples + ] + assert all( + torch.equal(restored.caption, cmp.caption) + for restored, cmp in zip(restored_samples, cmp_samples) + ) + assert_segments([first_sample, *cmp_samples]) + + final_packer_loader = get_loader( + get_train_dataset( + self.dataset_path, + batch_size=None, + packing_buffer_size=buffer_size, + worker_config=no_worker_config, + virtual_epoch_length=4, + 
shuffle_buffer_size=None, + max_samples_per_sequence=None, + task_encoder=FinalPackerTaskEncoder(), + ) + ) + final_packer_samples = list(final_packer_loader) + assert final_packer_loader.can_restore_sample() + restored_final_packer_sample = final_packer_loader.restore_sample( + final_packer_samples[0].__restore_key__ + ) + assert restored_final_packer_sample.__key__ == final_packer_samples[0].__key__ + assert torch.equal( + restored_final_packer_sample.caption, + final_packer_samples[0].caption, + ) + assert_segments(final_packer_samples) + def test_group_batch(self): class GroupingTaskEncoder( TaskEncoder[CaptioningSample, CaptioningSample, CaptioningSample, CaptioningSample] diff --git a/tests/test_metadataset_v2.py b/tests/test_metadataset_v2.py index e33992f7..c37ff654 100644 --- a/tests/test_metadataset_v2.py +++ b/tests/test_metadataset_v2.py @@ -32,7 +32,7 @@ from megatron.energon.flavors.webdataset.config import MAIN_FOLDER_NAME from megatron.energon.metadataset.loader import prepare_metadataset, traverse_metadataset from megatron.energon.metadataset.loader_interface import DatasetBlendMode -from megatron.energon.task_encoder.base import DefaultTaskEncoder +from megatron.energon.task_encoder.base import DefaultTaskEncoder, stateless from megatron.energon.wrappers.watchdog_dataset import WatchdogDataset from tests.epath_s3_emulator import setup_s3_emulator @@ -257,6 +257,86 @@ def test_metadataset(self): assert len(Counter(train_order1)) == 110 assert all(48 <= v <= 52 for v in Counter(train_order1).values()) + def test_group(self): + """Group-specific shuffle + packing keeps returned samples source-homogeneous.""" + mds_path = self.dataset_path / "group_blend.yaml" + with open(mds_path, "w") as f: + f.write( + "\n".join( + [ + "__module__: megatron.energon", + "__class__: MetadatasetV2", + "splits:", + " train:", + " blend:", + " - weight: 1", + " path: ds1", + " group: alpha", + " subflavors:", + " packing_source: ds1", + " - weight: 1", + " path: ds2", + 
" group: beta", + " subflavors:", + " packing_source: ds2", + ] + ) + ) + + leaves = traverse_metadataset(mds_path, split_part="train") + assert len(leaves) == 2 + assert {ref.group for ref in leaves} == {"alpha", "beta"} + + worker_config = WorkerConfig(rank=0, world_size=1, num_workers=0, seed_offset=0) + loaded = load_dataset(mds_path).get_datasets( + training=True, + split_part="train", + worker_config=worker_config, + ) + assert loaded.blend_mode == DatasetBlendMode.DATASET_WEIGHT + assert {d.group for d in loaded.datasets} == {"alpha", "beta"} + + class GroupIsolationEncoder(DefaultTaskEncoder): + """Each returned packed sample must come from exactly one packing source.""" + + @stateless + def encode_sample(self, sample: TextSample) -> TextSample: + return sample + + def select_samples_to_pack(self, samples: list[TextSample]) -> list[list[TextSample]]: + return [samples] + + @stateless + def pack_selected_samples(self, samples: list[TextSample]) -> TextSample: + packing_sources = { + sample.__subflavors__.get("packing_source") + for sample in samples + if sample.__subflavors__ is not None + } + assert len(packing_sources) == 1, ( + "Mixed sources in one packed sample: " + f"{packing_sources} for keys {[sample.__key__ for sample in samples]}" + ) + return TextSample.derive_from( + samples[0], + __key__=",".join(s.__key__ for s in samples), + __restore_key__=(), + text=f"{next(iter(packing_sources))}:" + " | ".join(s.text for s in samples), + ) + + torch.manual_seed(42) + packed_ds = get_train_dataset( + mds_path, + worker_config=worker_config, + batch_size=2, + packing_buffer_size=8, + shuffle_buffer_size={"alpha": 8, "beta": None}, + max_samples_per_sequence=None, + task_encoder=GroupIsolationEncoder(), + virtual_epoch_length=10, + ) + list(get_loader(packed_ds)) + def test_nested_metadataset(self): torch.manual_seed(42) worker_config = WorkerConfig( @@ -375,9 +455,15 @@ def test_traverse_metadataset_preserves_missing_v2_leaf_and_aux(self): "splits:", " 
train:", " path: missing_ds", + " subflavors:", + " source: missing_leaf_metadataset_v2.yaml", + " number: 42", + " mds: nested_val", " aux:", " labels: missing_aux", " media: filesystem://media", + " group: abc", + " shuffle_over_epochs_multiplier: 2", ] ), encoding="utf-8", @@ -392,6 +478,13 @@ def test_traverse_metadataset_preserves_missing_v2_leaf_and_aux(self): "labels": EPath(self.dataset_path / "missing_aux"), "media": EPath(self.dataset_path / "media"), } + assert refs[0].subflavors == { + "source": "missing_leaf_metadataset_v2.yaml", + "number": 42, + "mds": "nested_val", + } + assert refs[0].group == "abc" + assert refs[0].shuffle_over_epochs_multiplier == 2 def test_joined_metadataset(self): torch.manual_seed(42)