Last active
August 21, 2020 03:41
-
-
Save sritchie/8ea060a32c3696a7b98a26c57f290635 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from setuptools import find_packages, setup | |
# Coordinates of the prebuilt jaxlib GPU wheels. This follows the style of the
# jaxlib installation instructions here:
# https://github.com/google/jax#pip-installation
#
# jax_artifact() joins these values into a wheel URL of the form
#   {BASE_URL}/{CUDA_VERSION}/jaxlib-{version}-{PYTHON_VERSION}-none-{PLATFORM}.whl
# so every value must describe a wheel that is actually published for the
# pinned jaxlib version.
# NOTE(review): check the available builds under BASE_URL before changing any
# of these.
PYTHON_VERSION = "cp37"  # CPython wheel tag; pip only installs it on a matching interpreter
CUDA_VERSION = "cuda101"  # CUDA build directory; previously listed: cuda90, cuda92, cuda100, cuda101
PLATFORM = "linux_x86_64"  # wheel platform tag; no other value was listed as supported
BASE_URL = "https://storage.googleapis.com/jax-releases"  # root of the jaxlib wheel bucket
def jax_artifact(version, gpu=False, *, cuda_version=None,
                 python_version=None, platform=None, base_url=None):
    """Return the pip requirement string that installs ``jaxlib`` at ``version``.

    For a CPU install this is a plain version pin, ``jaxlib==<version>``, which
    pip resolves from the package index. For a GPU install it is a PEP 508
    direct reference, ``jaxlib @ <url>``, pointing at a prebuilt wheel located at
    ``<base_url>/<cuda_version>/jaxlib-<version>-<python_version>-none-<platform>.whl``.

    NOTE(review): PyPI rejects uploads whose metadata contains direct URL
    references, so the GPU form presumably only suits local/private installs.
    Confirm before publishing.

    Args:
      version: jaxlib version string, e.g. ``"0.1.43"``.
      gpu: if true, return a direct reference to the CUDA wheel instead of an
        index pin.
      cuda_version: override for the module-level ``CUDA_VERSION``.
      python_version: override for the module-level ``PYTHON_VERSION``.
      platform: override for the module-level ``PLATFORM``.
      base_url: override for the module-level ``BASE_URL``. A trailing ``/``
        is tolerated.
      (The four overrides are keyword-only and ignored when ``gpu`` is false.)

    Returns:
      A requirement string usable in ``install_requires`` / ``extras_require``.
    """
    if not gpu:
        return f"jaxlib=={version}"

    # Fall back to the module constants at call time, not as default-argument
    # values, so the function definition does not depend on definition order.
    if cuda_version is None:
        cuda_version = CUDA_VERSION
    if python_version is None:
        python_version = PYTHON_VERSION
    if platform is None:
        platform = PLATFORM
    if base_url is None:
        base_url = BASE_URL

    # Strip a trailing slash so an override like "https://host/path/" does not
    # produce a "//" in the wheel URL.
    root = base_url.rstrip("/")
    wheel = f"jaxlib-{version}-{python_version}-none-{platform}.whl"
    return f"jaxlib @ {root}/{cuda_version}/{wheel}"
def readme(path="README.md"):
    """Return the text of the project README for use as ``long_description``.

    Args:
      path: file to read. A relative path is resolved against the current
        working directory. Defaults to ``"README.md"``.

    Returns:
      The file contents as a string, or ``None`` if the file does not exist.
      A missing README therefore does not break packaging.
    """
    try:
        # Decode as UTF-8 explicitly. Without an encoding, open() uses the
        # locale's preferred encoding, which is not UTF-8 everywhere (e.g.
        # cp1252 on many Windows setups). Non-ASCII Markdown could then fail
        # with UnicodeDecodeError.
        with open(path, encoding="utf-8") as rf:
            return rf.read()
    except FileNotFoundError:
        return None
# Pinned JAX versions. JAX_VERSION is an unconditional requirement below;
# JAXLIB_VERSION is used by the "cpu"/"gpu" extras passed to setup().
JAXLIB_VERSION = "0.1.43"
JAX_VERSION = "0.1.62"

# Runtime dependencies installed unconditionally. jaxlib is supplied through
# the "cpu"/"gpu" extras rather than here.
# Keep one entry per line, each with a trailing comma: adjacent string literals
# without a comma are silently concatenated into one (invalid) requirement.
REQUIRED_PACKAGES = [
    "pg8000>=1.16.1",
    "uv-metrics>=0.4.2",
    "fs",
    "fs-gcsfs",
    f"jax=={JAX_VERSION}",
]
setup(
    name='my_project',
    version="0.0.1",
    # No custom cmdclass. The version is pinned above rather than derived by
    # versioneer, and the `with_versioneer(...)` helper this call used is not
    # defined or imported in this file, so it raised NameError.
    description='Getting it done.',
    long_description=readme(),
    # readme() reads README.md, so declare Markdown. When no content type is
    # given, PyPI renders the long description as reStructuredText.
    long_description_content_type='text/markdown',
    author='Sam Ritchie',
    # NOTE(review): this value looks like an e-mail address replaced by a web
    # page's address obfuscation when the file was copied. Confirm the real
    # address.
    author_email='[email protected]',
    # NOTE(review): points at the caliban repository, not "my_project". Confirm
    # this is intended.
    url='https://github.com/google/caliban',
    # Exclude the test/doc trees including their subpackages: an exclude
    # pattern of 'tests' alone does not match 'tests.<sub>'.
    packages=find_packages(exclude=('tests', 'tests.*', 'docs', 'docs.*')),
    install_requires=REQUIRED_PACKAGES,
    # jaxlib comes as a CPU build (index pin) or a GPU build (direct wheel URL).
    # Pick one via `pip install .[cpu]` or `pip install .[gpu]`.
    extras_require={
        "cpu": [jax_artifact(JAXLIB_VERSION, gpu=False)],
        "gpu": [jax_artifact(JAXLIB_VERSION, gpu=True)],
    },
    include_package_data=True,
)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment