Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions src/megatron/energon/epathlib/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# Copyright (c) 2025, NVIDIA CORPORATION.
# SPDX-License-Identifier: BSD-3-Clause

from megatron.energon.epathlib.epath import EPath
from megatron.energon.epathlib.epath import DEFAULT_PROFILE_NAME, EPath

__all__ = ["EPath"]
__all__ = ["DEFAULT_PROFILE_NAME", "EPath"]
13 changes: 10 additions & 3 deletions src/megatron/energon/epathlib/epath.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,13 @@

from megatron.energon.deprecation import warn_deprecated

# Profile name that marks an EPath as pointing at the local file system.
# Start from the built-in fallback and replace it with the name reserved by
# multistorageclient whenever that package can be imported.
DEFAULT_PROFILE_NAME = "default"
try:
    from multistorageclient.config import RESERVED_POSIX_PROFILE_NAME
except ImportError:
    # multistorageclient is unavailable (or does not export the constant):
    # keep the fallback assigned above.
    pass
else:
    DEFAULT_PROFILE_NAME = RESERVED_POSIX_PROFILE_NAME


class EPath:
"""EPath work much like pathlib.Path but does support local as well as remote file systems.
Expand Down Expand Up @@ -42,11 +49,11 @@ def __init__(
else:
if isinstance(initial_path, PathlibPath):
path = str(initial_path.absolute())
profile = "default"
profile = DEFAULT_PROFILE_NAME
else:
protocol, profile, path = self._split_protocol(initial_path)
if protocol is None or protocol == "file":
profile = "default"
profile = DEFAULT_PROFILE_NAME
path = str(PathlibPath(path).absolute())
elif protocol == "rclone":
warn_deprecated("rclone:// protocol is deprecated. Use msc:// instead.")
Expand Down Expand Up @@ -188,7 +195,7 @@ def url(self) -> str:
return f"msc://{self.profile}{int_path_str}"

def is_local(self) -> bool:
return self.profile == "default"
return self.profile == DEFAULT_PROFILE_NAME

def is_dir(self) -> bool:
try:
Expand Down
66 changes: 53 additions & 13 deletions tests/test_av_decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import sys
import time
import unittest
from dataclasses import fields
from pathlib import Path

import av
Expand Down Expand Up @@ -79,6 +80,33 @@ def tensors_close(tensor1: torch.Tensor, tensor2: torch.Tensor, tolerance: float
return mae <= tolerance


def avmetadata_equal(a: AVMetadata, b: AVMetadata, *, ndigits: int = 3) -> bool:
    """Compare two AVMetadata instances field by field, tolerating float noise.

    Non-float fields must be exactly equal. A field is treated as a float
    field when the value on either side is a ``float`` or ``np.floating``.
    Two float values match when they round to the same number at ``ndigits``
    decimals, or when they differ by at most half a unit of the last kept
    decimal. The second rule keeps nearly identical values that straddle a
    rounding boundary (e.g. 0.00049999 vs. 0.00050001) from being reported
    as different.

    Args:
        a: First instance; its dataclass fields drive the comparison.
        b: Second instance; must be of exactly the same type as ``a``.
        ndigits: Number of decimals that are significant for float fields.

    Returns:
        True if ``a`` and ``b`` are the same object, or have the same type and
        every field matches under the rules above; False otherwise.
    """
    if a is b:
        return True
    if type(a) is not type(b):
        return False

    float_types = (float, np.floating)
    # Half a unit of the last significant decimal, i.e. roughly the threshold
    # that unittest's assertAlmostEqual(places=ndigits) applies.
    tolerance = 0.5 * 10.0 ** (-ndigits)

    for field in fields(a):
        left = getattr(a, field.name)
        right = getattr(b, field.name)

        if not (isinstance(left, float_types) or isinstance(right, float_types)):
            # Plain (non-float) field: require exact equality.
            if left != right:
                return False
            continue

        if left is None or right is None:
            # One side holds a float while the other is unset: never equal.
            return False

        left_value = float(left)
        right_value = float(right)
        same_rounded = round(left_value, ndigits) == round(right_value, ndigits)
        # Written as `<=` (not a negated `>`) so that NaN never counts as close.
        within_tolerance = abs(left_value - right_value) <= tolerance
        if not (same_rounded or within_tolerance):
            return False

    return True


class TestFastseek(unittest.TestCase):
"""Test fastseek functionality."""

Expand Down Expand Up @@ -495,7 +523,7 @@ def test_decode_metadata(self):
audio_num_samples=3028992,
),
AVMetadata(
video_duration=63.03333333333333,
video_duration=63.033,
video_num_frames=1891,
video_fps=30.0,
video_width=192,
Expand All @@ -510,20 +538,30 @@ def test_decode_metadata(self):
["tests/data/sync_test.mkv", "tests/data/sync_test.mp4"], expected_metadata
):
av_decoder = AVDecoder(io.BytesIO(Path(video_file).read_bytes()))
assert av_decoder.get_metadata(get_audio_num_samples=True) == expected_metadata, (
f"Metadata does not match expected metadata for {video_file}"

actual_metadata = av_decoder.get_metadata(get_audio_num_samples=True)
assert avmetadata_equal(actual_metadata, expected_metadata), (
f"Metadata does not match expected metadata for {video_file}: "
f"{actual_metadata} != {expected_metadata}"
)

assert av_decoder.get_video_duration(get_frame_count=False) in (
(expected_metadata.video_duration, None),
(expected_metadata.video_duration, expected_metadata.video_num_frames),
getvid_duration, getvid_frame_count = av_decoder.get_video_duration(
get_frame_count=False
)
assert av_decoder.get_video_duration(get_frame_count=True) == (
expected_metadata.video_duration,
expected_metadata.video_num_frames,
self.assertAlmostEqual(getvid_duration, expected_metadata.video_duration, places=3)

if getvid_frame_count is not None:
self.assertEqual(getvid_frame_count, expected_metadata.video_num_frames)

getvid_duration, getvid_frame_count = av_decoder.get_video_duration(
get_frame_count=True
)
self.assertAlmostEqual(getvid_duration, expected_metadata.video_duration, places=3)
self.assertEqual(getvid_frame_count, expected_metadata.video_num_frames)

assert av_decoder.get_audio_duration() == expected_metadata.audio_duration
self.assertAlmostEqual(
av_decoder.get_audio_duration(), expected_metadata.audio_duration, places=3
)
assert av_decoder.get_video_fps() == expected_metadata.video_fps
assert av_decoder.get_audio_samples_per_second() == expected_metadata.audio_sample_rate

Expand Down Expand Up @@ -657,7 +695,7 @@ def test_pickle_decoder(self):

# Verify metadata matches
unpickled_metadata = unpickled_decoder.get_metadata()
assert unpickled_metadata == original_metadata, (
assert avmetadata_equal(unpickled_metadata, original_metadata), (
f"Unpickled metadata {unpickled_metadata} does not match original {original_metadata}"
)

Expand Down Expand Up @@ -900,8 +938,10 @@ def test_decode_metadata(self):
["tests/data/test_audio.flac", "tests/data/test_audio.wav"], expected_metadata
):
av_decoder = AVDecoder(io.BytesIO(Path(audio_file).read_bytes()))
assert av_decoder.get_metadata(get_audio_num_samples=True) == expected_metadata, (
f"Metadata does not match expected metadata for {audio_file}: {av_decoder.get_metadata()}"
actual_metadata = av_decoder.get_metadata(get_audio_num_samples=True)
assert avmetadata_equal(actual_metadata, expected_metadata), (
f"Metadata does not match expected metadata for {audio_file}: "
f"{actual_metadata} != {expected_metadata}"
)

assert av_decoder.get_audio_duration() == expected_metadata.audio_duration
Expand Down
3 changes: 2 additions & 1 deletion tests/test_crudedataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,8 @@ def cook_text(sample: dict) -> TextSample:

@stateless
def cook_other(sample: dict) -> TextSample:
d = pickle.loads(sample["pkl"])
pkl_value = sample["pkl"]
d = pickle.loads(pkl_value) if isinstance(pkl_value, (bytes, str)) else pkl_value
return TextSample(
**basic_sample_keys(sample),
text=f"<{sample['txt']}|{d['idx']}>",
Expand Down
12 changes: 6 additions & 6 deletions tests/test_epathlib.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
import sys
import unittest

from megatron.energon.epathlib import EPath
from megatron.energon.epathlib import DEFAULT_PROFILE_NAME, EPath
from tests.epath_s3_emulator import setup_s3_emulator


Expand Down Expand Up @@ -174,15 +174,15 @@ def test_s3_path_resolution(self):
def test_multi_storage_client(self):
"""Test the Multi-Storage Client integration"""
# Test path handling
p = EPath("msc://default/etc/resolv.conf")
p = EPath(f"msc://{DEFAULT_PROFILE_NAME}/etc/resolv.conf")
assert str(p) == "/etc/resolv.conf", str(p)
assert p.is_file()

p2 = p / ".." / "hosts"
assert str(p2) == "/etc/hosts", str(p2)

# Test glob
p3 = EPath("msc://default/etc/")
p3 = EPath(f"msc://{DEFAULT_PROFILE_NAME}/etc/")
assert p3.is_dir()
for i in p3.glob("*.conf"):
assert str(i).endswith(".conf")
Expand All @@ -193,13 +193,13 @@ def test_multi_storage_client(self):
assert len(fp.read()) > 0

# Test move and delete
p4 = EPath("msc://default/tmp/random_file_0001")
p4 = EPath(f"msc://{DEFAULT_PROFILE_NAME}/tmp/random_file_0001")
if p4.is_file():
p4.unlink()
with p4.open("w") as fp:
fp.write("*****")
assert p4.is_file()
p5 = EPath("msc://default/tmp/random_file_0002")
p5 = EPath(f"msc://{DEFAULT_PROFILE_NAME}/tmp/random_file_0002")
if p5.is_file():
p5.unlink()
assert p5.is_file() is False
Expand Down Expand Up @@ -239,7 +239,7 @@ def test_multiprocessing(self):

def test_multiprocessing_msc(self):
"""Test EPath in multiprocessing context"""
p = EPath("msc://default/tmp/random_file_0001")
p = EPath(f"msc://{DEFAULT_PROFILE_NAME}/tmp/random_file_0001")
with p.open("w") as fp:
fp.write("*****")

Expand Down
Loading