Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
92 changes: 92 additions & 0 deletions test/test_copyright_headers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
Comment thread
godobyte marked this conversation as resolved.

import re
from pathlib import Path

# For distributed open source and proprietary code, we must include a copyright header in source every file:
_copyright_header_re = re.compile(
r"Copyright Amazon\.com, Inc\. or its affiliates\. All Rights Reserved\.", re.IGNORECASE
)
_generated_by_scm = re.compile(
r"# file generated by (setuptools[_-]scm|vcs-versioning)", re.IGNORECASE
)


class FileMissingCopyRight(Exception):
filepath: Path

def __init__(self, message, filepath):
super().__init__(message)
self.filepath = filepath


def _check_file(filename: Path) -> None:
with open(filename, encoding="utf-8", errors="ignore") as infile:
lines_read = 0
for line in infile:
if _copyright_header_re.search(line):
return # success
lines_read += 1
if lines_read > 10:
raise FileMissingCopyRight(
f"Could not find a valid Amazon.com copyright header in the top of {filename}."
" Please add one.",
Path(filename),
)
# __init__.py files are usually empty, this is to catch that.
raise FileMissingCopyRight(
f"Could not find a valid Amazon.com copyright header in the top of {filename}."
" Please add one.",
Path(filename),
)


def _is_version_file(filename: Path) -> bool:
if filename.name != "_version.py":
return False
with open(filename, encoding="utf-8", errors="ignore") as infile:
lines_read: list[str] = []
for line in infile:
if _generated_by_scm.search(line):
return True
lines_read.append(line)
if len(lines_read) > 10:
break
print(f"_version.py file ({filename}) found that is not generated by scm:\n{lines_read}")
return False


def test_copyright_headers():
"""Verifies every .py file has an Amazon copyright header."""
root_project_dir = Path(__file__)
# The root of the project is the directory that contains the test directory.
while not (root_project_dir / "test").exists():
root_project_dir = root_project_dir.parent
# Choose only a few top level directories to test.
# That way we don't snag any virtual envs a developer might create, at the risk of missing
# some top level .py files.
top_level_dirs = ["src", "test", "scripts"]
file_count = 0
failed_files = set()
for top_level_dir in top_level_dirs:
for glob_pattern in ("**/*.py", "**/*.sh"):
for path in Path(root_project_dir / top_level_dir).glob(glob_pattern):
print(path)
if not _is_version_file(path):
try:
_check_file(path)
except FileMissingCopyRight as e:
failed_files.add(e.filepath)
file_count += 1

if failed_files:
formatted_failed_files = "\n\t".join(str(filepath) for filepath in failed_files)
raise Exception(
f"Found {len(failed_files)} files without copyright headers:\n\t{formatted_failed_files}"
)

print(f"test_copyright_headers checked {file_count} files successfully.")


if __name__ == "__main__":
test_copyright_headers()
Loading