-
Notifications
You must be signed in to change notification settings - Fork 15
/
Copy pathsetup.py
134 lines (110 loc) · 3.79 KB
/
setup.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
# Copyright (c) 2024-2025, NVIDIA CORPORATION & AFFILIATES. ALL RIGHTS RESERVED.
#
# SPDX-License-Identifier: Apache-2.0
import atexit
import glob
import os
import shutil
import sys
import tempfile
from Cython.Build import cythonize
from setuptools import setup, Extension, find_packages
from packaging.version import Version
import Cython
# Parse the installed Cython version up front; it decides which compiler
# directives are handed to cythonize() further down.
cython_version = Version(Cython.__version__)

# This is tricky: sys.path gets overwritten at different stages of the build
# flow, so the in-tree "builder" directory is added to it by hand. It is
# appended, so any "utils" module found earlier on sys.path would win.
source_root = os.path.abspath(os.path.dirname(__file__))
_builder_dir = os.path.join(source_root, "builder")
sys.path.append(_builder_dir)
import utils  # type: ignore # this is builder.utils # noqa: E402
# The main binding modules; for each of them calculate_modules() derives the
# companion "cy*" and "_internal" extensions automatically.
ext_modules = [
    f"nvmath.bindings.{lib}"
    for lib in ("cublas", "cublasLt", "cusolver", "cusolverDn", "cufft", "cusparse", "curand")
]
# The NVPL FFT bindings are only built on Linux.
if sys.platform == "linux":
    ext_modules += ["nvmath.bindings.nvpl.fft"]
# WAR: Check if this is still valid
# TODO: can this support cross-compilation?
if sys.platform == "linux":
src_files = glob.glob("*/bindings/**/_internal/*_linux.pyx", recursive=True)
elif sys.platform == "win32":
src_files = glob.glob("*/bindings/**/_internal/*_windows.pyx", recursive=True)
else:
raise RuntimeError(f"platform is unrecognized: {sys.platform}")
dst_files = []
for src in src_files:
# Set up a temporary file; it must be under the cache directory so
# that atomic moves within the same filesystem can be guaranteed
with tempfile.NamedTemporaryFile(delete=False, dir=".") as f:
shutil.copy2(src, f.name)
f_name = f.name
dst = src.replace("_linux", "").replace("_windows", "")
# atomic move with the destination guaranteed to be overwritten
os.replace(f_name, f"./{dst}")
dst_files.append(dst)
@atexit.register
def cleanup_dst_files():
    """Remove the generated platform-neutral .pyx copies listed in ``dst_files``.

    Registered with ``atexit`` so it runs when the interpreter exits. Removal
    is best-effort: a file that is already gone is skipped silently, and any
    other ``OSError`` is reported on stderr without stopping the removal of
    the remaining files.
    """
    for dst in dst_files:
        try:
            os.remove(dst)
        except FileNotFoundError:
            pass
        except OSError as err:
            # Keep going so one undeletable file does not leave the rest behind.
            print(f"warning: could not remove generated file {dst!r}: {err}", file=sys.stderr)
def calculate_modules(module):
    """Return the three Cython extensions that make up one binding module.

    For a dotted module path such as ``"nvmath.bindings.cublas"`` this builds:

    * ``nvmath.bindings.cublas`` from ``nvmath/bindings/cublas.pyx``
    * ``nvmath.bindings.cycublas`` from ``nvmath/bindings/cycublas.pyx``
    * ``nvmath.bindings._internal.cublas`` from
      ``nvmath/bindings/_internal/cublas.pyx``

    The ``.pyx`` paths are joined with ``os.path.join`` and every extension is
    compiled as C++.

    Args:
        module: Dotted module path with at least two components.

    Returns:
        A ``(lowpp_ext, cy_ext, inter_ext)`` tuple of ``Extension`` objects.
    """

    def _make_ext(parts):
        # The dotted extension name and the .pyx source path come from the same parts.
        pyx_path = os.path.join(*parts[:-1], f"{parts[-1]}.pyx")
        return Extension(".".join(parts), sources=[pyx_path], language="c++")

    parts = module.split(".")
    lowpp_ext = _make_ext(parts)
    # Same package, leaf name prefixed with "cy".
    cy_ext = _make_ext([*parts[:-1], f"cy{parts[-1]}"])
    # "_internal" sub-package inserted right before the leaf name.
    inter_ext = _make_ext([*parts[:-1], "_internal", parts[-1]])
    return lowpp_ext, cy_ext, inter_ext
# Expand every main module into its three extensions, then append the shared
# internal utils extension.
# Note: the extension attributes are overwritten in build_extension()
_all_extensions = []
for _main_module in ext_modules:
    _all_extensions.extend(calculate_modules(_main_module))
_all_extensions.append(
    Extension(
        "nvmath.bindings._internal.utils",
        sources=["nvmath/bindings/_internal/utils.pyx"],
        language="c++",
    )
)
ext_modules = _all_extensions
# Route the build_ext and bdist_wheel commands to the customized
# implementations in builder/utils.py.
cmdclass = dict(
    build_ext=utils.build_ext,
    bdist_wheel=utils.bdist_wheel,
)
# Choose the Cython compiler directives based on the installed Cython version:
# function signatures are always embedded in docstrings, and on Cython 3+ the
# performance hints are switched off as well.
compiler_directives = {"embedsignature": True}
if cython_version.major >= 3:
    compiler_directives["show_performance_hints"] = False
# Cythonize all extensions (Python 3 language level, version-specific
# directives), then register them with setuptools together with the nvmath
# packages and the custom build commands.
cythonized_extensions = cythonize(
    ext_modules,
    verbose=True,
    language_level=3,
    compiler_directives=compiler_directives,
)
setup(
    cmdclass=cmdclass,
    zip_safe=False,
    packages=find_packages(include=["nvmath", "nvmath.*"]),
    ext_modules=cythonized_extensions,
)