Skip to content

Instantly share code, notes, and snippets.

@chenyaofo
Created December 13, 2023 08:28
Show Gist options
  • Save chenyaofo/9675af9ee918f5c17e84bb49581f0d8e to your computer and use it in GitHub Desktop.
Save chenyaofo/9675af9ee918f5c17e84bb49581f0d8e to your computer and use it in GitHub Desktop.
Main script from kubeedge/sedna-storage-initializer:v0.3.0
#!/usr/bin/env python3
# Copyright 2021 The KubeEdge Authors.
# Copyright 2020 kubeflow.org.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# modify from https://github.com/kubeflow/kfserving/blob/master/python/kfserving/kfserving/storage.py # noqa
import concurrent.futures
import glob
import gzip
import json
import logging
import mimetypes
import os
import re
import sys
import shutil
import tempfile
import tarfile
import zipfile
import minio
import requests
from urllib.parse import urlparse
_S3_PREFIX = "s3://"
_OBS_PREFIX = "obs://"
_LOCAL_PREFIX = "file://"
_URI_RE = "https?://(.+)/(.+)"
_HTTP_PREFIX = "http(s)://"
_HEADERS_SUFFIX = "-headers"
SUPPORT_PROTOCOLS = (_OBS_PREFIX, _S3_PREFIX, _LOCAL_PREFIX, _HTTP_PREFIX)
LOG = logging.getLogger(__name__)
def setup_logger():
    """Configure root logging format and this module's log level.

    The level is read from the ``LOG_LEVEL`` environment variable
    (default ``INFO``) and upper-cased, so ``debug`` is accepted as well
    as ``DEBUG``.
    """
    # Named log_format to avoid shadowing the builtin format().
    log_format = ('%(asctime)s %(levelname)s '
                  '%(funcName)s:%(lineno)s] %(message)s')
    logging.basicConfig(format=log_format)
    # Logger.setLevel only recognises upper-case level names; a
    # lower-case value would raise ValueError at startup.
    LOG.setLevel(os.getenv('LOG_LEVEL', 'INFO').upper())
def _normalize_uri(uri: str) -> str:
    """Rewrite *uri* into the canonical scheme used for dispatch.

    A bare absolute path ``/x`` becomes ``file://x`` and an ``obs://``
    URI becomes ``s3://``.  Only the first matching rule is applied, and
    only its first occurrence is replaced; other URIs are returned as-is.
    """
    rewrites = (("/", _LOCAL_PREFIX), (_OBS_PREFIX, _S3_PREFIX))
    matched = next(
        ((old, new) for old, new in rewrites if uri.startswith(old)),
        None,
    )
    if matched is None:
        return uri
    old, new = matched
    return uri.replace(old, new, 1)
def download(uri: str, out_dir: str = None) -> str:
    """Download *uri* into the local directory *out_dir*.

    Supported protocols: ``s3://`` (``obs://`` is treated as s3),
    ``file://`` or a bare absolute path, and ``http(s)://``.  For s3 and
    http(s) sources, downloaded .tar/.tar.gz/.zip archives are extracted
    into *out_dir*; local sources are symlinked, not copied.

    :param uri: source location.
    :param out_dir: destination directory, created if missing.  When
        ``None``, a fresh temporary directory is created and used.
    :return: the directory the contents were placed in.
    :raises Exception: if the URI scheme is not recognised.
    """
    if out_dir is None:
        # os.path.exists(None) would raise TypeError; honour the default.
        out_dir = tempfile.mkdtemp()
    LOG.info("Copying contents of %s to local %s", uri, out_dir)
    uri = _normalize_uri(uri)
    # exist_ok avoids a check-then-create race on the directory.
    os.makedirs(out_dir, exist_ok=True)
    if uri.startswith(_S3_PREFIX):
        download_s3(uri, out_dir)
    elif uri.startswith(_LOCAL_PREFIX):
        download_local(uri, out_dir)
    elif re.search(_URI_RE, uri):
        download_from_uri(uri, out_dir)
    else:
        raise Exception("Cannot recognize storage type for %s.\n"
                        "%r are the current available storage type." %
                        (uri, SUPPORT_PROTOCOLS))
    LOG.info("Successfully copied %s to %s", uri, out_dir)
    return out_dir
def _parse_indirect_manifest(path):
    """Parse an indirect-download manifest file at *path*.

    Blank lines and lines starting with ``#`` are ignored.  The first
    remaining line is the base URI; every later line is a file path
    relative to that base.

    :return: ``(base_uri, files)``; *base_uri* is ``None`` for a manifest
        with no content lines, *files* is a set of relative paths.
    """
    base_uri = None
    files = set()
    with open(path, encoding="utf-8") as manifest:
        for raw in manifest:
            line = raw.strip()
            if not line or line.startswith('#'):
                continue
            if base_uri is None:
                base_uri = line
            else:
                files.add(line)
    return base_uri, files


def indirect_download(indirect_uri: str, out_dir: str = None) -> None:
    """Download the files listed in the manifest found at *indirect_uri*.

    The manifest names a base URI followed by relative file paths; each
    path ``p`` is fetched from ``<base_uri>/p`` into
    ``<out_dir>/<dirname(p)>``.  Only s3 (and obs, normalised to s3) base
    URIs are supported; other schemes are logged and skipped.

    :param indirect_uri: location of the manifest (any scheme download()
        supports); it must resolve to exactly one file.
    :param out_dir: destination directory, created if missing.
    :raises Exception: if *indirect_uri* does not yield exactly one file.
    """
    # The manifest copy is only needed while parsing; TemporaryDirectory
    # removes it afterwards (a plain mkdtemp() was never cleaned up).
    with tempfile.TemporaryDirectory() as tmpdir:
        download(indirect_uri, tmpdir)
        files = os.listdir(tmpdir)
        if len(files) != 1:
            raise Exception("indirect url %s should be file, not directory"
                            % indirect_uri)
        base_uri, download_files = _parse_indirect_manifest(
            os.path.join(tmpdir, files[0]))

    if not download_files:
        LOG.info("no files to download for indirect url %s",
                 indirect_uri)
        return
    os.makedirs(out_dir, exist_ok=True)
    LOG.info("To download %s files IN-DIRECT %s to %s",
             len(download_files), indirect_uri, out_dir)
    # base_uri is set whenever download_files is non-empty: file lines are
    # only collected after the base line has been read.
    uri = _normalize_uri(base_uri)
    # only support s3 for indirect download
    if not uri.startswith(_S3_PREFIX):
        LOG.warning("unsupported %s for indirect url %s, skipped",
                    uri, indirect_uri)
        return
    download_s3_with_multi_files(download_files, uri, out_dir)
    LOG.info("Successfully download files IN-DIRECT %s to %s",
             indirect_uri, out_dir)
def download_s3(uri, out_dir: str):
    """Download every object under the s3 *uri* prefix into *out_dir*.

    :param uri: ``s3://bucket[/path]`` location.
    :param out_dir: local destination directory.
    :raises RuntimeError: if no object was downloaded for *uri*.
    """
    client = _create_minio_client()
    count = _download_s3(client, uri, out_dir)
    if count == 0:
        raise RuntimeError("Failed to fetch files. "
                           "The path %s does not exist." % uri)
    LOG.info("downloaded %d files for %s.", count, uri)
def download_s3_with_multi_files(download_files,
                                 base_uri, base_out_dir):
    """Concurrently download several files relative to an s3 base URI.

    Each relative path ``p`` in *download_files* is fetched from
    ``<base_uri>/p`` into ``<base_out_dir>/<dirname(p)>``.  Paths that
    match no object are logged and skipped; an exception raised by any
    individual download propagates to the caller.
    """
    client = _create_minio_client()
    total_count = 0
    with concurrent.futures.ThreadPoolExecutor() as executor:
        # Map each future back to its file so a failure is reported
        # against the right path; logging the submit-loop variable would
        # always name the last file submitted.
        futures = {}
        for dfile in set(download_files):
            uri = base_uri.rstrip("/") + "/" + dfile
            out_dir = os.path.join(base_out_dir, os.path.dirname(dfile))
            future = executor.submit(_download_s3, client, uri, out_dir)
            futures[future] = dfile
        for done in concurrent.futures.as_completed(futures):
            count = done.result()
            if count == 0:
                LOG.warning("failed to download %s in base uri(%s)",
                            futures[done], base_uri)
                continue
            total_count += count
    LOG.info("downloaded %d files for base_uri %s to local dir %s.",
             total_count, base_uri, base_out_dir)
def _download_s3(client, uri, out_dir):
    """Download all objects under an s3 *uri* into *out_dir*.

    *uri* has the form ``s3://bucket[/path]``.  Objects are listed
    recursively under ``path``.  Each object's key, with ``path``
    stripped, gives its location under *out_dir*; an object whose key
    equals ``path`` exactly is saved under its basename.  Every
    downloaded file is then passed to ``_extract_compress``.

    :param client: Minio client used for listing and fetching.
    :return: the number of files downloaded.
    """
    bucket_name, _, bucket_path = (
        uri.replace(_S3_PREFIX, "", 1).partition("/"))
    objects = client.list_objects(bucket_name,
                                  prefix=bucket_path,
                                  recursive=True,
                                  use_api_v1=True)
    count = 0
    for obj in objects:
        if obj.is_dir:
            continue
        key = obj.object_name
        # S3 prefixes are plain string prefixes: "models/v1" also lists
        # "models/v10/...".  Keep only the exact object, or keys that
        # continue with a "/" path separator.
        if (bucket_path and not bucket_path.endswith("/")
                and key != bucket_path
                and not key.startswith(bucket_path + "/")):
            continue
        # Replace the listing prefix in the object key with out_dir.
        subdir_object_key = key[len(bucket_path):].strip("/")
        local_file = os.path.join(
            out_dir, subdir_object_key or os.path.basename(key))
        LOG.debug("downloading count:%d, file:%s",
                  count, subdir_object_key)
        # fget_object handles directory creation if it does not exist.
        client.fget_object(bucket_name, key, local_file)
        _extract_compress(local_file, out_dir)
        count += 1
    return count
def download_local(uri, out_dir=None):
    """Expose a local path inside *out_dir* via symlinks.

    *uri* is a ``file://`` URI.  When *out_dir* is ``None``, the resolved
    local path is returned and nothing is linked.  Otherwise a file is
    linked as ``out_dir/<name>``; for a directory, each of its entries
    (including dot-files) is linked into *out_dir*.

    :return: *out_dir*, or the local path when *out_dir* is ``None``.
    :raises RuntimeError: if the local path does not exist.
    """
    local_path = uri.replace(_LOCAL_PREFIX, "/", 1)
    if not os.path.exists(local_path):
        raise RuntimeError("Local path %s does not exist." % (uri))
    if out_dir is None:
        return local_path
    os.makedirs(out_dir, exist_ok=True)
    # List entries directly instead of glob("<dir>/*"): glob skips
    # dot-files and treats "[", "*", "?" in the path as patterns.
    if os.path.isdir(local_path):
        sources = [os.path.join(local_path, name)
                   for name in sorted(os.listdir(local_path))]
    else:
        sources = [local_path]
    for src in sources:
        dest_path = os.path.join(out_dir, os.path.basename(src))
        LOG.info("Linking: %s to %s", src, dest_path)
        os.symlink(src, dest_path)
    return out_dir
def download_from_uri(uri, out_dir=None):
    """Download a single http(s) *uri* into *out_dir* and extract it.

    Extra request headers may be supplied as a JSON object in the
    environment variable ``<hostname>-headers``.  A gzip-compressed tar
    (e.g. ``.tar.gz``) is decompressed while streaming and saved as
    ``<filename>.tar``; any other file is saved verbatim as
    ``<filename>``.  The result is passed to ``_extract_compress``.

    :return: *out_dir*.
    :raises ValueError: if the URI path has no file name.
    :raises RuntimeError: if the server does not answer with status 200.
    """
    url = urlparse(uri)
    filename = os.path.basename(url.path)
    if filename == '':
        raise ValueError('No filename contained in URI: %s' % (uri))
    mimetype, encoding = mimetypes.guess_type(url.path)
    local_path = os.path.join(out_dir, filename)
    # Per-host extra headers (JSON object), e.g. for authentication.
    headers_json = os.getenv(url.hostname + _HEADERS_SUFFIX, "{}")
    headers = json.loads(headers_json)
    with requests.get(uri, stream=True, headers=headers) as response:
        if response.status_code != 200:
            raise RuntimeError("URI: %s returned a %s response code." %
                               (uri, response.status_code))
        # Only gzip-compressed tarballs are decompressed and renamed to
        # ".tar".  A gzipped non-tar file (e.g. data.csv.gz) is saved
        # verbatim, since renaming it to ".tar" makes extraction fail.
        if encoding == 'gzip' and mimetype == 'application/x-tar':
            stream = gzip.GzipFile(fileobj=response.raw)
            local_path = os.path.join(out_dir, f'{filename}.tar')
        else:
            stream = response.raw
        with open(local_path, 'wb') as out:
            shutil.copyfileobj(stream, out)
    return _extract_compress(local_path, out_dir)
def _extract_compress(local_path, out_dir):
mimetype, encoding = mimetypes.guess_type(local_path)
if mimetype in ["application/x-tar", "application/zip"]:
if mimetype == "application/x-tar":
archive = tarfile.open(local_path, 'r', encoding='utf-8')
else:
archive = zipfile.ZipFile(local_path, 'r')
archive.extractall(out_dir)
archive.close()
os.remove(local_path)
return out_dir
def _create_minio_client():
    """Build a Minio client from environment variables.

    ``S3_ENDPOINT_URL`` (default ``http://s3.amazonaws.com``) selects the
    endpoint.  TLS is used when its scheme is ``https`` or when no scheme
    is given (e.g. ``minio:9000``, treated as a bare host[:port]).
    Credentials come from ``ACCESS_KEY_ID`` and ``SECRET_ACCESS_KEY``,
    which default to empty strings.
    """
    endpoint = os.getenv("S3_ENDPOINT_URL", "http://s3.amazonaws.com")
    if "://" not in endpoint:
        # urlparse("minio:9000") yields scheme="minio" and an empty netloc;
        # a leading "//" makes urlparse treat the value as the netloc.
        endpoint = "//" + endpoint
    url = urlparse(endpoint)
    use_ssl = url.scheme == 'https' if url.scheme else True
    return minio.Minio(
        url.netloc,
        access_key=os.getenv("ACCESS_KEY_ID", ""),
        secret_key=os.getenv("SECRET_ACCESS_KEY", ""),
        secure=use_ssl
    )
def main():
    """CLI entry point: process ``src_uri dest_path`` pairs from argv.

    A dest_path that starts with ``$INDIRECT_URL_MARK`` (default ``@``)
    marks src_uri as an indirect manifest; the mark is stripped from the
    destination before downloading.  Exits with status 1 on bad usage.
    """
    setup_logger()
    argc = len(sys.argv)
    if argc < 2 or argc % 2 == 0:
        LOG.error("Usage: download.py "
                  "src_uri dest_path [src_uri dest_path]")
        sys.exit(1)
    mark = os.getenv("INDIRECT_URL_MARK", "@")
    # Odd-indexed args are sources, the following even-indexed ones dests.
    for src_uri, dest_path in zip(sys.argv[1::2], sys.argv[2::2]):
        LOG.info("Initializing, args: src_uri [%s] dest_path [%s]",
                 src_uri, dest_path)
        if not dest_path.startswith(mark):
            download(src_uri, dest_path)
            continue
        indirect_download(src_uri, dest_path[len(mark):])
# Run the CLI only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment