diff --git a/openeo/rest/connection.py b/openeo/rest/connection.py index dc6fffc4b..6615bbd9f 100644 --- a/openeo/rest/connection.py +++ b/openeo/rest/connection.py @@ -5,6 +5,7 @@ import json import logging import shlex +import os import sys import warnings from collections import OrderedDict @@ -1100,7 +1101,7 @@ def list_files(self): files = [UserFile(file['path'], connection=self, metadata=file) for file in files] return VisualList("data-table", data=files, parameters={'columns': 'files'}) - def get_file(self, path) -> UserFile: + def get_file(self, path: str) -> UserFile: """ Gets a (virtual) file for the user workspace @@ -1108,6 +1109,15 @@ def get_file(self, path) -> UserFile: """ return UserFile(path, connection=self) + def upload_file(self, source: Union[Path, str]) -> UserFile: + """ + Uploads a file to the user workspace and stores it with the filename of the source given. + If a file with the name exists on the user workspace it will be replaced. + + :return: UserFile object. + """ + return self.get_file(os.path.basename(source)).upload(source) + def _build_request_with_process_graph(self, process_graph: Union[dict, Any], **kwargs) -> dict: """ Prepare a json payload with a process graph to submit to /result, /services, /jobs, ... diff --git a/openeo/rest/userfile.py b/openeo/rest/userfile.py index bda53952f..4507c6706 100644 --- a/openeo/rest/userfile.py +++ b/openeo/rest/userfile.py @@ -1,6 +1,8 @@ import typing from typing import Any, Dict, List, Union +import os from pathlib import Path +from util import ensure_dir if typing.TYPE_CHECKING: # Imports for type checking only (circular import issue at runtime). @@ -10,9 +12,9 @@ class UserFile: """Represents a file in the user-workspace of openeo.""" - def __init__(self, path: str, connection: 'Connection', metadata: Dict[str, Any] = {}): + def __init__(self, path: str, connection: 'Connection', metadata: Dict[str, Any] = None): self.path = path - self.metadata = metadata + self.metadata = metadata or {"path": path} self.connection = connection def __repr__(self): @@ -21,38 +23,41 @@ def __repr__(self): def _get_endpoint(self) -> str: return "/files/{}".format(self.path) - def get_metadata(self, key) -> Any: - """ Get metadata about the file, e.g. file size (key: 'size') or modification date (key: 'modified').""" - if key in self.metadata: - return self.metadata[key] - else: - return None + def download(self, target: Union[Path, str] = None) -> Path: + """ + Downloads a user-uploaded file to the given location. - def download_file(self, target: Union[Path, str]) -> Path: - """ Downloads a user-uploaded file.""" + :param target: download target path. Can be an existing folder + (in which case the file name advertised by backend will be used) + or full file name. By default, the working directory will be used. + """ # GET /files/{path} response = self.connection.get(self._get_endpoint(), expected_status=200, stream=True) - path = Path(target) - with path.open(mode="wb") as f: + target = Path(target or Path.cwd()) + if target.is_dir(): + target = target / os.path.basename(self.path) + ensure_dir(target.parent) + + with target.open(mode="wb") as f: for chunk in response.iter_content(chunk_size=None): f.write(chunk) - return path + return target - def upload_file(self, source: Union[Path, str]): + def upload(self, source: Union[Path, str]): # PUT /files/{path} """ Uploaded (or replaces) a user-uploaded file.""" path = Path(source) with path.open(mode="rb") as f: self.connection.put(self._get_endpoint(), expected_status=200, data=f) - def delete_file(self): + def delete(self): """ Delete a user-uploaded file.""" # DELETE /files/{path} self.connection.delete(self._get_endpoint(), expected_status=204) def to_dict(self) -> Dict[str, Any]: """ Returns the provided metadata as dict.""" - return self.metadata if "path" in self.metadata else {"path": self.path} \ No newline at end of file + return self.metadata \ No newline at end of file