Main script from kubeedge/sedna-storage-initializer:v0.3.0
#!/usr/bin/env python3

# Copyright 2021 The KubeEdge Authors.
# Copyright 2020 kubeflow.org.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# modified from https://github.com/kubeflow/kfserving/blob/master/python/kfserving/kfserving/storage.py  # noqa

import concurrent.futures
import glob
import gzip
import json
import logging
import mimetypes
import os
import re
import shutil
import sys
import tarfile
import tempfile
import zipfile
from urllib.parse import urlparse

import minio
import requests

_S3_PREFIX = "s3://"
_OBS_PREFIX = "obs://"
_LOCAL_PREFIX = "file://"
_URI_RE = "https?://(.+)/(.+)"
_HTTP_PREFIX = "http(s)://"
_HEADERS_SUFFIX = "-headers"

SUPPORT_PROTOCOLS = (_OBS_PREFIX, _S3_PREFIX, _LOCAL_PREFIX, _HTTP_PREFIX)

LOG = logging.getLogger(__name__)

def setup_logger():
    fmt = '%(asctime)s %(levelname)s %(funcName)s:%(lineno)s] %(message)s'
    logging.basicConfig(format=fmt)
    LOG.setLevel(os.getenv('LOG_LEVEL', 'INFO'))

def _normalize_uri(uri: str) -> str:
    for src, dst in [
            ("/", _LOCAL_PREFIX),
            (_OBS_PREFIX, _S3_PREFIX)
    ]:
        if uri.startswith(src):
            return uri.replace(src, dst, 1)
    return uri

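# A rough illustration of _normalize_uri (hypothetical URIs): a bare local
# path such as "/models/mnist" becomes "file://models/mnist" (download_local
# later maps the file:// prefix back to "/"), and "obs://bucket/model"
# becomes "s3://bucket/model", so OBS is served by the s3 code path.
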
def download(uri: str, out_dir: str = None) -> str:
    """ Download the uri to a local directory.

    Supported protocols: http(s), s3, obs, file.
    Note: when the uri ends with .tar.gz/.tar/.zip, the archive is extracted.
    """
    LOG.info("Copying contents of %s to local %s", uri, out_dir)

    uri = _normalize_uri(uri)

    if not os.path.exists(out_dir):
        os.makedirs(out_dir)

    if uri.startswith(_S3_PREFIX):
        download_s3(uri, out_dir)
    elif uri.startswith(_LOCAL_PREFIX):
        download_local(uri, out_dir)
    elif re.search(_URI_RE, uri):
        download_from_uri(uri, out_dir)
    else:
        raise Exception("Cannot recognize storage type for %s.\n"
                        "%r are the currently supported storage types." %
                        (uri, SUPPORT_PROTOCOLS))

    LOG.info("Successfully copied %s to %s", uri, out_dir)
    return out_dir

def indirect_download(indirect_uri: str, out_dir: str = None) -> str:
    """ Download the files listed by an indirect uri to a local directory.

    The indirect uri must point to a single index file whose first
    non-comment line is a base uri and whose remaining lines are file
    paths relative to that base uri. Only s3 base uris are supported.
    """
    tmpdir = tempfile.mkdtemp()
    download(indirect_uri, tmpdir)

    files = os.listdir(tmpdir)
    if len(files) != 1:
        raise Exception("indirect url %s should be a file, not a directory"
                        % indirect_uri)

    download_files = set()
    with open(os.path.join(tmpdir, files[0])) as f:
        base_uri = None
        for line_no, line in enumerate(f):
            line = line.strip()
            if line.startswith('#'):
                continue
            if line:
                if base_uri is None:
                    base_uri = line
                else:
                    download_files.add(line)

    if not download_files:
        LOG.info("no files to download for indirect url %s",
                 indirect_uri)
        return

    if not os.path.exists(out_dir):
        os.makedirs(out_dir)

    LOG.info("To download %s files IN-DIRECT %s to %s",
             len(download_files), indirect_uri, out_dir)

    uri = _normalize_uri(base_uri)
    # only s3 is supported for indirect download
    if uri.startswith(_S3_PREFIX):
        download_s3_with_multi_files(download_files, uri, out_dir)
    else:
        LOG.warning("unsupported %s for indirect url %s, skipped",
                    uri, indirect_uri)
        return

    LOG.info("Successfully downloaded files IN-DIRECT %s to %s",
             indirect_uri, out_dir)
    return

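# A sketch of the index file consumed by indirect_download (hypothetical
# bucket and file names):
#
#   s3://my-bucket/models
#   # lines starting with '#' are ignored
#   resnet/weights.pb
#   resnet/config.json
#
# Each listed file is fetched from <base_uri>/<relative path> into the
# matching subdirectory of out_dir.
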
def download_s3(uri, out_dir: str):
    client = _create_minio_client()
    count = _download_s3(client, uri, out_dir)

    if count == 0:
        raise RuntimeError("Failed to fetch files. "
                           "The path %s does not exist." % (uri))
    LOG.info("downloaded %d files for %s.", count, uri)

def download_s3_with_multi_files(download_files,
                                 base_uri, base_out_dir):
    client = _create_minio_client()
    total_count = 0
    with concurrent.futures.ThreadPoolExecutor() as executor:
        # Map each future to the file it downloads, so completed futures
        # can be reported against the right file name (futures may finish
        # in any order, so the loop variable dfile cannot be used here).
        todos = {}
        for dfile in set(download_files):
            dir_ = os.path.dirname(dfile)
            uri = base_uri.rstrip("/") + "/" + dfile
            out_dir = os.path.join(base_out_dir, dir_)

            todos[executor.submit(_download_s3, client, uri, out_dir)] = dfile

        for done in concurrent.futures.as_completed(todos):
            count = done.result()
            if count == 0:
                LOG.warning("failed to download %s in base uri(%s)",
                            todos[done], base_uri)
                continue
            total_count += count

    LOG.info("downloaded %d files for base_uri %s to local dir %s.",
             total_count, base_uri, base_out_dir)

def _download_s3(client, uri, out_dir):
    bucket_args = uri.replace(_S3_PREFIX, "", 1).split("/", 1)
    bucket_name = bucket_args[0]
    bucket_path = bucket_args[1] if len(bucket_args) > 1 else ""

    objects = client.list_objects(bucket_name,
                                  prefix=bucket_path,
                                  recursive=True,
                                  use_api_v1=True)
    count = 0
    for obj in objects:
        # Replace any prefix from the object key with out_dir
        subdir_object_key = obj.object_name[len(bucket_path):].strip("/")
        # fget_object handles directory creation if it does not exist
        if not obj.is_dir:
            local_file = os.path.join(
                out_dir,
                subdir_object_key or os.path.basename(obj.object_name)
            )
            LOG.debug("downloading count:%d, file:%s",
                      count, subdir_object_key)
            client.fget_object(bucket_name, obj.object_name, local_file)
            _extract_compress(local_file, out_dir)

            count += 1

    return count

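# For example (hypothetical URI), _download_s3 on "s3://my-bucket/models/mnist"
# lists every object under the "models/mnist" prefix of bucket "my-bucket" and
# writes each one below out_dir with that prefix stripped, extracting any
# .tar/.tar.gz/.zip object it fetches.
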
def download_local(uri, out_dir=None):
    local_path = uri.replace(_LOCAL_PREFIX, "/", 1)
    if not os.path.exists(local_path):
        raise RuntimeError("Local path %s does not exist." % (uri))

    if out_dir is None:
        return local_path
    elif not os.path.isdir(out_dir):
        os.makedirs(out_dir)

    if os.path.isdir(local_path):
        local_path = os.path.join(local_path, "*")

    for src in glob.glob(local_path):
        _, tail = os.path.split(src)
        dest_path = os.path.join(out_dir, tail)
        LOG.info("Linking: %s to %s", src, dest_path)
        os.symlink(src, dest_path)
    return out_dir

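# Note that local sources are symlinked rather than copied: e.g. (hypothetical
# paths) download_local("file://models", "/out") links every entry of /models
# into /out instead of copying the bytes.
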
def download_from_uri(uri, out_dir=None):
    url = urlparse(uri)
    filename = os.path.basename(url.path)
    if filename == '':
        raise ValueError('No filename contained in URI: %s' % (uri))

    _, encoding = mimetypes.guess_type(url.path)
    local_path = os.path.join(out_dir, filename)

    # Get extra request headers for this host from the environment
    host_uri = url.hostname
    headers_json = os.getenv(host_uri + _HEADERS_SUFFIX, "{}")
    headers = json.loads(headers_json)

    with requests.get(uri, stream=True, headers=headers) as response:
        if response.status_code != 200:
            raise RuntimeError("URI: %s returned a %s response code." %
                               (uri, response.status_code))

        if encoding == 'gzip':
            # Decompress on the fly and store the result with a .tar
            # suffix so that _extract_compress recognizes the archive.
            stream = gzip.GzipFile(fileobj=response.raw)
            local_path = os.path.join(out_dir, f'{filename}.tar')
        else:
            stream = response.raw
        with open(local_path, 'wb') as out:
            shutil.copyfileobj(stream, out)
    return _extract_compress(local_path, out_dir)

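# The per-host headers mechanism above reads an environment variable named
# "<hostname>-headers" holding a JSON object; e.g. (hypothetical host and
# token), setting the variable
#   example.com-headers='{"Authorization": "Bearer mytoken"}'
# attaches that Authorization header to every download from example.com.
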
def _extract_compress(local_path, out_dir):
    mimetype, encoding = mimetypes.guess_type(local_path)
    if mimetype in ["application/x-tar", "application/zip"]:
        if mimetype == "application/x-tar":
            archive = tarfile.open(local_path, 'r', encoding='utf-8')
        else:
            archive = zipfile.ZipFile(local_path, 'r')
        archive.extractall(out_dir)
        archive.close()

        os.remove(local_path)

    return out_dir

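# mimetypes.guess_type reports "application/x-tar" for *.tar and *.tar.gz and
# "application/zip" for *.zip, so exactly those archives are unpacked and then
# deleted; tarfile.open in mode 'r' transparently handles gzip compression.
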
def _create_minio_client():
    url = urlparse(os.getenv("S3_ENDPOINT_URL", "http://s3.amazonaws.com"))
    use_ssl = url.scheme == 'https' if url.scheme else True
    return minio.Minio(
        url.netloc,
        access_key=os.getenv("ACCESS_KEY_ID", ""),
        secret_key=os.getenv("SECRET_ACCESS_KEY", ""),
        secure=use_ssl
    )

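# The S3 client is configured entirely from the environment, e.g.
# (hypothetical values):
#   S3_ENDPOINT_URL=https://minio.example.com:9000
#   ACCESS_KEY_ID=minioadmin
#   SECRET_ACCESS_KEY=minioadmin
# TLS is enabled when the endpoint scheme is https, or when no scheme
# is given at all.
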
def main():
    setup_logger()
    if len(sys.argv) < 2 or len(sys.argv) % 2 == 0:
        LOG.error("Usage: download.py "
                  "src_uri dest_path [src_uri dest_path ...]")
        sys.exit(1)

    indirect_mark = os.getenv("INDIRECT_URL_MARK", "@")

    for i in range(1, len(sys.argv) - 1, 2):
        src_uri = sys.argv[i]
        dest_path = sys.argv[i + 1]

        LOG.info("Initializing, args: src_uri [%s] dest_path [%s]",
                 src_uri, dest_path)

        if dest_path.startswith(indirect_mark):
            indirect_download(src_uri, dest_path[len(indirect_mark):])
        else:
            download(src_uri, dest_path)


if __name__ == '__main__':
    main()
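
# Example invocation (hypothetical URIs and paths):
#   python3 download.py s3://my-bucket/model /models @s3://my-bucket/index.txt /data
# A dest_path prefixed with INDIRECT_URL_MARK (default "@") switches that
# pair to indirect mode: src_uri must then be an index file as sketched
# above, and the listed files land under the path after the "@".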