import ast from pathlib import Path MMP_PARENTS = ["mmp"] + [f"a{i}" for i in range(1, 9)] MMP_LOCAL_MODULES = [ "anchor_grid", "annotation", "bbr", "dataset", "evallib", "label_grid", "main", "model", "nms", ] def check_bad_imports(file_path: Path | list[Path]): if not isinstance(file_path, list): file_path = [file_path] bad_imports = [] for path in file_path: with open(path, "r") as file: tree = ast.parse(file.read(), filename=path) absolute_imports = set() for node in ast.walk(tree): if isinstance(node, ast.Import): for n in node.names: for req in MMP_PARENTS: if req in n.name: absolute_imports.add(n.lineno) for req in MMP_LOCAL_MODULES: if req == n.name: absolute_imports.add(n.lineno) if isinstance(node, ast.ImportFrom): if node.level == 0: for req in MMP_PARENTS: if req in node.module: absolute_imports.add(node.lineno) for req in MMP_LOCAL_MODULES: if req == node.module: absolute_imports.add(node.lineno) if len(absolute_imports) > 0: bad_imports.append((path, sorted(absolute_imports))) if len(bad_imports) != 0: message = "\n There are absolute imports in the following files:\n" for path, line_numbers in bad_imports: message += f"{path}: {', '.join(str(num) for num in line_numbers)}\n" assert len(bad_imports) == 0, message