Add GCP downloader in example (#419)
Make FT handler download from GCP (#417)

* Make FT handler download from GCP

* lint

* fix lint

---------

Co-authored-by: Ajay Saini <ajay@mosaicml.com>
Co-authored-by: Jeffrey Chen <jeffrey@mosaicml.com>
3 people committed Jul 20, 2023
1 parent fb449f8 commit d1cd929
Showing 1 changed file with 38 additions and 9 deletions.
examples/inference-deployments/mpt/mpt_ft_handler.py (38 additions, 9 deletions)
@@ -27,6 +27,7 @@
 
 def download_convert(s3_path: Optional[str] = None,
                      hf_path: Optional[str] = None,
+                     gcp_path: Optional[str] = None,
                      gpus: int = 1,
                      force_conversion: bool = False):
     """Download model and convert to FasterTransformer format.
@@ -35,16 +36,35 @@ def download_convert(s3_path: Optional[str] = None,
         s3_path (str): Path for model location in an s3 bucket.
         hf_path (str): Name of the model as on HF hub (e.g., mosaicml/mpt-7b-instruct) or local folder name containing
             the model (e.g., mpt-7b-instruct)
+        gcp_path (str): Path for model location in a gcp bucket.
         gpus (int): Number of gpus to use for inference (Default: 1)
         force_conversion (bool): Force conversion to FT even if some features may not work as expected in FT (Default: False)
     """
-    if not s3_path and not hf_path:
+    if not s3_path and not gcp_path and not hf_path:
         raise RuntimeError(
-            'Either s3_path or hf_path must be provided to download_convert')
+            'Either s3_path, gcp_path, or hf_path must be provided to download_convert'
+        )
     model_name_or_path: str = ''
+
+    # If s3_path or gcp_path is provided, initialize the s3 client for download
+    s3 = None
+    download_from_path = None
     if s3_path:
         # s3 creds need to already be present as env vars
         s3 = boto3.client('s3')
+        download_from_path = s3_path
+    if gcp_path:
+        s3 = boto3.client(
+            's3',
+            region_name='auto',
+            endpoint_url='https://storage.googleapis.com',
+            aws_access_key_id=os.environ['GCS_KEY'],
+            aws_secret_access_key=os.environ['GCS_SECRET'],
+        )
+        download_from_path = gcp_path
+
+    # If either s3_path or gcp_path is provided, download files
+    if s3:
         model_name_or_path = LOCAL_MODEL_PATH
 
         # Download model files
@@ -55,22 +75,31 @@ def download_convert(s3_path: Optional[str] = None,
         else:
             Path(LOCAL_MODEL_PATH).mkdir(parents=True, exist_ok=True)
 
-        print(f'Downloading model from path: {s3_path}')
+        print(f'Downloading model from path: {download_from_path}')
 
-        parsed_path = urlparse(s3_path)
+        parsed_path = urlparse(download_from_path)
+        prefix = parsed_path.path.lstrip('/')  # type: ignore
 
         objs = s3.list_objects_v2(
             Bucket=parsed_path.netloc,
-            Prefix=parsed_path.path.lstrip('/'),
+            Prefix=prefix,
         )
+        downloaded_file_set = set(os.listdir(LOCAL_MODEL_PATH))
         for obj in objs['Contents']:
             file_key = obj['Key']
             try:
                 file_name = os.path.basename(file_key)
-                s3.download_file(Bucket=parsed_path.netloc,
-                                 Key=file_key,
-                                 Filename=os.path.join(
-                                     LOCAL_MODEL_PATH, file_name))
+                if not file_name or file_name.startswith('.'):
+                    # Ignore hidden files
+                    continue
+                if file_name not in downloaded_file_set:
+                    print(
+                        f'Downloading {os.path.join(LOCAL_MODEL_PATH, file_name)}...'
+                    )
+                    s3.download_file(Bucket=parsed_path.netloc,
+                                     Key=file_key,
+                                     Filename=os.path.join(
+                                         LOCAL_MODEL_PATH, file_name))
             except botocore.exceptions.ClientError as e:
                 print(
                     f'Error downloading file with key: {file_key} with error: {e}'
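As a usage note, here is a minimal sketch of how download_convert might be called with the new gcp_path argument. The bucket, prefix, and gs:// scheme below are hypothetical; the function only needs urlparse to read the bucket name from the netloc and the object prefix from the path, and it expects the GCS HMAC interoperability keys to already be set as the GCS_KEY and GCS_SECRET environment variables.

import os

# Hypothetical HMAC interoperability credentials; the handler reads these
# exact environment variable names before building the client.
os.environ['GCS_KEY'] = '<gcs-hmac-access-id>'
os.environ['GCS_SECRET'] = '<gcs-hmac-secret>'

# Hypothetical bucket and prefix: urlparse() yields netloc='my-gcs-bucket'
# (the bucket) and path='/mpt-7b-instruct' (the object prefix).
download_convert(gcp_path='gs://my-gcs-bucket/mpt-7b-instruct', gpus=1)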

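The GCP branch reuses boto3 rather than a separate GCS SDK: it points the S3 client at Google Cloud Storage's S3-compatible XML endpoint (https://storage.googleapis.com) and authenticates with HMAC interoperability keys. A standalone sketch of that pattern, with a placeholder bucket and prefix, assuming GCS_KEY and GCS_SECRET are already exported:

import os

import boto3

# S3-compatible client for GCS: same boto3 API, different endpoint and keys.
gcs = boto3.client(
    's3',
    region_name='auto',
    endpoint_url='https://storage.googleapis.com',
    aws_access_key_id=os.environ['GCS_KEY'],
    aws_secret_access_key=os.environ['GCS_SECRET'],
)

# List one page of objects under the placeholder prefix and print their keys.
resp = gcs.list_objects_v2(Bucket='my-gcs-bucket', Prefix='mpt-7b-instruct/')
for obj in resp.get('Contents', []):
    print(obj['Key'], obj['Size'])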

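One caveat on the download loop: list_objects_v2 returns at most 1,000 keys per call, which is fine for a single checkpoint directory but not in general. Below is a hedged variant that keeps the handler's skip rules (hidden files, files already on disk) while paging through larger prefixes with boto3's paginator; download_prefix is a made-up helper name, not part of the handler.

import os


def download_prefix(client, bucket: str, prefix: str, local_dir: str) -> None:
    """Download every non-hidden object under a prefix, one page at a time."""
    os.makedirs(local_dir, exist_ok=True)
    already_present = set(os.listdir(local_dir))
    paginator = client.get_paginator('list_objects_v2')
    for page in paginator.paginate(Bucket=bucket, Prefix=prefix):
        for obj in page.get('Contents', []):
            file_name = os.path.basename(obj['Key'])
            if not file_name or file_name.startswith('.'):
                continue  # skip "directory" placeholder keys and hidden files
            if file_name in already_present:
                continue  # skip files downloaded on a previous run
            client.download_file(Bucket=bucket,
                                 Key=obj['Key'],
                                 Filename=os.path.join(local_dir, file_name))


# e.g. with the gcs client from the sketch above:
# download_prefix(gcs, 'my-gcs-bucket', 'mpt-7b-instruct/', '/tmp/mpt-7b-instruct')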