-
Notifications
You must be signed in to change notification settings - Fork 19
Expand file tree
/
Copy pathsetup.py
More file actions
116 lines (96 loc) · 3.74 KB
/
setup.py
File metadata and controls
116 lines (96 loc) · 3.74 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
import os
import sys
import subprocess
from setuptools import setup, find_packages
def fetch_requirements(path):
    """Read a pip-style requirements file and return its requirement lines.

    Args:
        path: Path to the requirements file (relative paths resolve against
            the current working directory).

    Returns:
        A list of requirement strings with surrounding whitespace removed.
        Blank lines and full-line ``#`` comments are dropped so they are not
        handed to setuptools as empty or comment-only "requirements".
    """
    # Explicit encoding: the platform default (e.g. cp1252 on Windows) would
    # otherwise decide how the file is decoded.
    with open(path, 'r', encoding='utf-8') as fd:
        stripped = (line.strip() for line in fd)
        return [line for line in stripped if line and not line.startswith('#')]
# Runtime dependencies come from requirements.txt; the optional "dev" extra
# (``pip install deepspeed-kernels[dev]``) is read from requirements-dev.txt.
install_requires = fetch_requirements('requirements/requirements.txt')
extras_require = dict(dev=fetch_requirements('requirements/requirements-dev.txt'))
def command_exists(cmd):
    """Return True if ``cmd`` resolves to an executable program.

    The lookup uses :func:`shutil.which`, which searches ``PATH`` (and
    honours ``PATHEXT`` on Windows) without spawning a shell or running
    ``cmd`` itself. Executing the command and inspecting its exit status
    cannot reliably tell "not installed" apart from "ran and exited non-zero".
    Piping stdout while only calling ``wait()`` can also block if the child
    writes a lot of output.

    Args:
        cmd: Program name (e.g. ``'git'``) or a path to an executable.

    Returns:
        ``True`` if an executable was found, ``False`` otherwise.
    """
    import shutil  # local import keeps this helper self-contained

    return shutil.which(cmd) is not None
# Write out version/git info: capture the short commit hash and branch name.
# Both default to "unknown" and stay that way when git is unavailable, when an
# explicit DS_KERNELS_BUILD_STRING is set, or when either git query fails.
git_hash_cmd = "git rev-parse --short HEAD"
git_branch_cmd = "git rev-parse --abbrev-ref HEAD"
git_hash = git_branch = "unknown"
if command_exists('git') and 'DS_KERNELS_BUILD_STRING' not in os.environ:
    try:
        git_hash = subprocess.check_output(git_hash_cmd, shell=True).decode('utf-8').strip()
        git_branch = subprocess.check_output(git_branch_cmd, shell=True).decode('utf-8').strip()
    except subprocess.CalledProcessError:
        # A failure in either query discards both values.
        git_hash = git_branch = "unknown"
# Ensure all submodules have been pulled in.
# Only attempt this from a source checkout that actually declares submodules:
# a .gitmodules file in the current directory, the same base the rest of this
# script uses for relative paths such as version.txt. Without that guard, git
# discovers repositories by walking up parent directories, so building from an
# unpacked sdist could touch an enclosing, unrelated repository.
git_submodules = "git submodule update --init --recursive"
if os.path.isfile('.gitmodules') and command_exists('git'):
    try:
        # Output is captured and discarded to keep the build log quiet.
        subprocess.check_output(git_submodules, shell=True)
    except subprocess.CalledProcessError as err:
        # Best effort: keep going, but report the failure instead of hiding it.
        print(f"WARNING: '{git_submodules}' failed with exit code {err.returncode}; "
              "continuing without updating submodules.",
              file=sys.stderr)
# Parse the ds-kernels version string from version.txt. A `with` block closes
# the handle promptly instead of leaving it open until garbage collection.
with open('version.txt', 'r', encoding='utf-8') as fd:
    version_str = fd.read().strip()

# Build specifiers like .devX can be added at install time. Otherwise, add the git hash.
# example: DS_KERNELS_BUILD_STRING=".dev20201022" python -m build --sdist --wheel
BUILD_STRING = 'DS_KERNELS_BUILD_STRING'
BUILD_FILE = 'build.txt'
build_string = os.environ.get(BUILD_STRING)

if build_string:
    # Build string env specified, probably building for distribution. Persist
    # it to build.txt (presumably so that installs from the produced sdist
    # pick it up through the branch below).
    with open(BUILD_FILE, 'w', encoding='utf-8') as fd:
        fd.write(build_string)
    version_str += build_string
elif os.path.isfile(BUILD_FILE):
    # build.txt exists, probably installing from distribution
    with open(BUILD_FILE, 'r', encoding='utf-8') as fd:
        version_str += fd.read().strip()
else:
    # None of the above, probably installing from source: append the git hash
    # as a local version label ("+<hash>").
    version_str += f'+{git_hash}'

# Write out the installed version so the package can report it at runtime.
with open("dskernels/version.py", 'w', encoding='utf-8') as fd:
    fd.write(f"__version__ = '{version_str}'\n")
# Native extensions built by this package. Both are compiled through the
# project's CMakeBuild command, registered below as the `build_ext` command.
from builder.builder import CMakeBuild
from builder.ft_gemm import FTGemmBuilder
from builder.inf_flash_attn import BlockedFlashBuilder

build_ext = {'build_ext': CMakeBuild}
ext_modules = [
    FTGemmBuilder(name="deepspeed_ft_gemm"),
    BlockedFlashBuilder(name="deepspeed_blocked_flash"),
]
# Descriptive package metadata, kept apart from the build configuration below.
package_metadata = dict(
    name="deepspeed-kernels",
    version=version_str,
    description='deepspeed kernels',
    author='DeepSpeed Team',
    author_email='deepspeed@microsoft.com',
    url='http://deepspeed.ai',
    project_urls={
        'Documentation': 'https://github.com/deepspeedai/DeepSpeed-Kernels',
        'Source': 'https://github.com/deepspeedai/DeepSpeed-Kernels',
    },
    classifiers=[
        'Programming Language :: Python :: 3.9',
        'Programming Language :: Python :: 3.10',
        'Programming Language :: Python :: 3.11',
    ],
)

setup(
    install_requires=install_requires,
    extras_require=extras_require,
    ext_modules=ext_modules,
    cmdclass=build_ext,
    include_package_data=True,
    # NOTE(review): this pattern matches only the top-level ``dskernels``
    # package; any subpackages would also need 'dskernels.*' to be included.
    packages=find_packages(include=['dskernels']),
    **package_metadata,
)