Skip to content

Instantly share code, notes, and snippets.

@sritchie
Last active August 21, 2020 03:41
Show Gist options
  • Save sritchie/8ea060a32c3696a7b98a26c57f290635 to your computer and use it in GitHub Desktop.
Save sritchie/8ea060a32c3696a7b98a26c57f290635 to your computer and use it in GitHub Desktop.
from setuptools import find_packages, setup
# Coordinates of the prebuilt jaxlib wheels, mirroring the pip-install
# instructions in the JAX README:
# https://github.com/google/jax#pip-installation
BASE_URL = "https://storage.googleapis.com/jax-releases"
PLATFORM = "linux_x86_64"  # platform tag in the wheel filename
PYTHON_VERSION = "cp37"  # Python tag in the wheel filename
CUDA_VERSION = "cuda101"  # other options: cuda90, cuda92, cuda100
def jax_artifact(version, gpu=False, *, python_version=None,
                 cuda_version=None, platform=None, base_url=None):
    """Return a PEP 508 requirement string for jaxlib ``version``.

    CPU builds are pinned by version (``jaxlib==<version>``). GPU builds are
    referenced directly by wheel URL (``jaxlib @ <url>``), built from the
    release bucket, CUDA version, Python tag and platform tag.

    Args:
      version: jaxlib version string, e.g. "0.1.43".
      gpu: if True, return a direct URL reference to the CUDA wheel.
      python_version: overrides the module-level PYTHON_VERSION (GPU only).
      cuda_version: overrides the module-level CUDA_VERSION (GPU only).
      platform: overrides the module-level PLATFORM (GPU only).
      base_url: overrides the module-level BASE_URL (GPU only).

    Returns:
      A requirement string usable in ``install_requires``/``extras_require``.
    """
    if not gpu:
        return f"jaxlib=={version}"

    # Module constants are looked up only when no override is given, so
    # callers can target a different build without mutating globals.
    base_url = BASE_URL if base_url is None else base_url
    cuda_version = CUDA_VERSION if cuda_version is None else cuda_version
    python_version = PYTHON_VERSION if python_version is None else python_version
    platform = PLATFORM if platform is None else platform

    prefix = f"{base_url}/{cuda_version}/jaxlib"
    wheel_suffix = f"{python_version}-none-{platform}.whl"
    location = f"{prefix}-{version}-{wheel_suffix}"
    return f"jaxlib @ {location}"
def readme(path="README.md"):
    """Return the text of the project README, or None if it is missing.

    A relative ``path`` is resolved against the directory containing this
    setup script rather than the current working directory, so the README
    is still found when setup.py is invoked from another directory. The
    file is decoded as UTF-8 so non-ASCII content does not depend on the
    locale's default encoding.

    Args:
      path: README location; relative paths are taken from this file's
        directory, absolute paths are used as-is.

    Returns:
      The README contents as a string, or None if the file does not exist.
    """
    import os

    here = os.path.dirname(os.path.abspath(__file__))
    try:
        with open(os.path.join(here, path), encoding="utf-8") as rf:
            return rf.read()
    except FileNotFoundError:
        return None
# Pinned versions for jax and its compiled companion package jaxlib.
JAXLIB_VERSION = "0.1.43"
JAX_VERSION = "0.1.62"

# Every entry needs a trailing comma: adjacent string literals are
# implicitly concatenated into a single (invalid) requirement otherwise.
REQUIRED_PACKAGES = [
    "pg8000>=1.16.1",
    "uv-metrics>=0.4.2",
    "fs",
    "fs-gcsfs",
    f"jax=={JAX_VERSION}",
]
def _versioneer_cmdclass():
    """Return versioneer's setuptools command classes, or {} without it.

    versioneer is imported lazily so setup.py still runs in environments
    where it is not installed; the default setuptools commands are used then.
    """
    try:
        import versioneer
    except ImportError:
        return {}
    return versioneer.get_cmdclass()


setup(
    name='my_project',
    version="0.0.1",
    cmdclass=_versioneer_cmdclass(),
    description='Getting it done.',
    long_description=readme(),
    # The long description is read from README.md, so declare it as
    # Markdown rather than letting index tooling assume reStructuredText.
    long_description_content_type="text/markdown",
    author='Sam Ritchie',
    author_email='[email protected]',
    url='https://github.com/google/caliban',
    packages=find_packages(exclude=('tests', 'docs')),
    install_requires=REQUIRED_PACKAGES,
    # jaxlib is selected via an extra: `pip install my_project[cpu]` or
    # `pip install my_project[gpu]`.
    extras_require={
        "cpu": [jax_artifact(JAXLIB_VERSION, gpu=False)],
        "gpu": [jax_artifact(JAXLIB_VERSION, gpu=True)],
    },
    include_package_data=True,
)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment