diff --git a/examples/inference-deployments/mpt/mpt_ft_handler.py b/examples/inference-deployments/mpt/mpt_ft_handler.py
index 570e9de02..2325eb9e7 100644
--- a/examples/inference-deployments/mpt/mpt_ft_handler.py
+++ b/examples/inference-deployments/mpt/mpt_ft_handler.py
@@ -27,6 +27,7 @@
 
 def download_convert(s3_path: Optional[str] = None,
                      hf_path: Optional[str] = None,
+                     gcp_path: Optional[str] = None,
                      gpus: int = 1,
                      force_conversion: bool = False):
     """Download model and convert to FasterTransformer format.
@@ -35,16 +36,35 @@ def download_convert(s3_path: Optional[str] = None,
         s3_path (str): Path for model location in an s3 bucket.
         hf_path (str): Name of the model as on HF hub (e.g., mosaicml/mpt-7b-instruct) or local folder name containing the model (e.g., mpt-7b-instruct)
+        gcp_path (str): Path for model location in a gcp bucket.
         gpus (int): Number of gpus to use for inference (Default: 1)
         force_conversion (bool): Force conversion to FT even if some features may not work as expected in FT (Default: False)
     """
-    if not s3_path and not hf_path:
+    if not s3_path and not gcp_path and not hf_path:
         raise RuntimeError(
-            'Either s3_path or hf_path must be provided to download_convert')
+            'Either s3_path, gcp_path, or hf_path must be provided to download_convert'
+        )
 
     model_name_or_path: str = ''
+
+    # If s3_path or gcp_path is provided, initialize the s3 client for download
+    s3 = None
+    download_from_path = None
     if s3_path:
         # s3 creds need to already be present as env vars
         s3 = boto3.client('s3')
+        download_from_path = s3_path
+    if gcp_path:
+        s3 = boto3.client(
+            's3',
+            region_name='auto',
+            endpoint_url='https://storage.googleapis.com',
+            aws_access_key_id=os.environ['GCS_KEY'],
+            aws_secret_access_key=os.environ['GCS_SECRET'],
+        )
+        download_from_path = gcp_path
+
+    # If either s3_path or gcp_path is provided, download files
+    if s3:
         model_name_or_path = LOCAL_MODEL_PATH
 
         # Download model files
@@ -55,22 +75,31 @@ def download_convert(s3_path: Optional[str] = None,
         else:
             Path(LOCAL_MODEL_PATH).mkdir(parents=True, exist_ok=True)
 
-        print(f'Downloading model from path: {s3_path}')
+        print(f'Downloading model from path: {download_from_path}')
 
-        parsed_path = urlparse(s3_path)
+        parsed_path = urlparse(download_from_path)
+        prefix = parsed_path.path.lstrip('/')  # type: ignore
 
         objs = s3.list_objects_v2(
             Bucket=parsed_path.netloc,
-            Prefix=parsed_path.path.lstrip('/'),
+            Prefix=prefix,
         )
+        downloaded_file_set = set(os.listdir(LOCAL_MODEL_PATH))
         for obj in objs['Contents']:
             file_key = obj['Key']
             try:
                 file_name = os.path.basename(file_key)
-                s3.download_file(Bucket=parsed_path.netloc,
-                                 Key=file_key,
-                                 Filename=os.path.join(
-                                     LOCAL_MODEL_PATH, file_name))
+                if not file_name or file_name.startswith('.'):
+                    # Ignore hidden files
+                    continue
+                if file_name not in downloaded_file_set:
+                    print(
+                        f'Downloading {os.path.join(LOCAL_MODEL_PATH, file_name)}...'
+                    )
+                    s3.download_file(Bucket=parsed_path.netloc,
+                                     Key=file_key,
+                                     Filename=os.path.join(
+                                         LOCAL_MODEL_PATH, file_name))
             except botocore.exceptions.ClientError as e:
                 print(
                     f'Error downloading file with key: {file_key} with error: {e}'
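
For context, a minimal usage sketch of the new gcp_path argument (not part of the diff). The bucket path and the import path below are placeholders, and the GCS HMAC credentials must already be exported as GCS_KEY and GCS_SECRET, since that is where the new boto3 client reads them:

    # Hypothetical example: exercising the new gcp_path branch of
    # download_convert. GCS is reached through its S3-compatible XML API,
    # so HMAC credentials must be set before running:
    #   export GCS_KEY=<hmac-access-id>
    #   export GCS_SECRET=<hmac-secret>
    from mpt_ft_handler import download_convert  # assumed local import path

    # Placeholder bucket/prefix: urlparse() treats the netloc as the bucket
    # and the path (with the leading '/' stripped) as the list_objects_v2
    # prefix, mirroring the existing s3_path handling.
    download_convert(gcp_path='gs://example-bucket/mpt-7b-instruct', gpus=1)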