diff --git a/lumen/sources/dremio_utils.py b/lumen/sources/dremio_utils.py
new file mode 100644
index 00000000..6343b2b6
--- /dev/null
+++ b/lumen/sources/dremio_utils.py
@@ -0,0 +1,91 @@
+from pyarrow import flight
+
+
+class DremioClientAuthMiddleware(flight.ClientMiddleware):
+    """
+    A ClientMiddleware that extracts the bearer token from
+    the authorization header returned by the Dremio
+    Flight Server Endpoint.
+
+    Parameters
+    ----------
+    factory : DremioClientAuthMiddlewareFactory
+        The factory to set call credentials if an
+        authorization header with bearer token is
+        returned by the Dremio server.
+    """
+
+    def __init__(self, factory):
+        self.factory = factory
+
+    def received_headers(self, headers):
+        """
+        Store the bearer token returned by the server on the factory.
+
+        Parameters
+        ----------
+        headers : dict
+            Response headers; each value is indexed as a list of
+            header values.
+
+        Raises
+        ------
+        flight.FlightUnauthenticatedError
+            If the response carries no authorization header.
+        """
+        # A credential captured by an earlier call is kept as-is.
+        if self.factory.call_credential:
+            return
+
+        auth_header_key = "authorization"
+
+        # Look the header up by name rather than folding over all
+        # headers, which only worked when it happened to come first.
+        authorization_header = next(
+            (
+                values for key, values in headers.items()
+                if key.lower() == auth_header_key
+            ),
+            None,
+        )
+        if not authorization_header:
+            raise flight.FlightUnauthenticatedError(
+                "Did not receive authorization header back from server."
+            )
+        bearer_token = authorization_header[0]
+        self.factory.set_call_credential(
+            [b"authorization", bearer_token.encode("utf-8")]
+        )
+
+
+class DremioClientAuthMiddlewareFactory(flight.ClientMiddlewareFactory):
+    """A factory that creates DremioClientAuthMiddleware(s)."""
+
+    def __init__(self):
+        self.call_credential = []
+
+    def start_call(self, info):
+        return DremioClientAuthMiddleware(self)
+
+    def set_call_credential(self, call_credential):
+        self.call_credential = call_credential
+
+
+class HttpDremioClientAuthHandler(flight.ClientAuthHandler):
+    """
+    Auth handler that sends basic-auth credentials during the
+    handshake and keeps the token read back from the server.
+    """
+
+    def __init__(self, username, password):
+        super().__init__()
+        self.basic_auth = flight.BasicAuth(username, password)
+        self.token = None
+
+    def authenticate(self, outgoing, incoming):
+        auth = self.basic_auth.serialize()
+        outgoing.write(auth)
+        self.token = incoming.read()
+
+    def get_token(self):
+        return self.token
diff --git a/lumen/sources/duckdb.py b/lumen/sources/duckdb.py
index d85f6c1c..a78ae932 100644
--- a/lumen/sources/duckdb.py
+++ b/lumen/sources/duckdb.py
@@ -239,8 +239,8 @@ def get_tables(self):
     def get_sql_expr(self, table: str):
         if isinstance(self.tables, dict):
             table = self.tables[table]
-        if '(' not in table and ')' not in table:
-            table = f'"{table}"'
+        if '(' not in table and ')' not in table and '"' not in table:
+            table = f"'{table}'"
         if 'select ' in table.lower():
             sql_expr = table
         else:
@@ -274,6 +274,11 @@ def get_schema(
         schemas = {}
         sql_limit = SQLLimit(limit=limit or 1)
         for entry in tables:
+
+            # duckdb does not support quoted hierarchies like "A"."B"
+            if '."' in entry:
+                entry = entry.replace('"', '')
+
             if not self.load_schema:
                 schemas[entry] = {}
                 continue
@@ -314,3 +319,250 @@ def get_schema(
                 schema[col]['inclusiveMinimum'] = cast(minmax_data[f'{col}_min'].iloc[0])
                 schema[col]['inclusiveMaximum'] = cast(minmax_data[f'{col}_max'].iloc[0])
         return schemas if table is None else schemas[table]
+
+
+class DremioDuckDBSource(DuckDBSource):
+    """
+    DremioDuckDBSource provides a simple wrapper around the DuckDB SQL
+    connector, extended to connect to a Dremio server via Apache Arrow Flight.
+    """
+
+    cert = param.String(default=None, doc="""
+        Path to the certificate file. If unset, no custom root
+        certificates are passed to the Flight client.""")
+
+    dremio_uri = param.String(doc="URI of the Dremio server.")
+
+    tls = param.Boolean(default=True, doc="Enable encryption (TLS).")
+
+    username = param.String(default=None, doc="Dremio username.")
+
+    password = param.String(default=None, doc="Dremio password or token.")
+
+    dialect = 'dremio'
+
+    def __init__(self, **params):
+        from pyarrow import flight
+
+        from lumen.sources.dremio_utils import (
+            DremioClientAuthMiddlewareFactory, HttpDremioClientAuthHandler,
+        )
+
+        super().__init__(**params)
+
+        protocol, hostname, username, password = self._process_uri(
+            tls=self.tls, username=self.username, password=self.password)
+
+        connection_args = {'middleware': [DremioClientAuthMiddlewareFactory()]}
+        if self.tls and self.cert:
+            # Pass the PEM file's raw bytes as the TLS root certificates.
+            with open(self.cert, 'rb') as f:
+                connection_args["tls_root_certs"] = f.read()
+
+        self._dremio_client = flight.FlightClient(f'{protocol}://{hostname}', **connection_args)
+        try:
+            bearer_token = self._dremio_client.authenticate_basic_token(username, password)
+            self._headers = [bearer_token]
+        except Exception:
+            if self.tls:
+                raise
+            # Without TLS, fall back to the handler-based handshake.
+            handler = HttpDremioClientAuthHandler(username, password)
+            self._dremio_client.authenticate(handler, options=flight.FlightCallOptions())
+            self._headers = []
+
+    def _process_uri(self, tls=False, username=None, password=None):
+        """
+        Extracts protocol, hostname, username and password from the
+        ``dremio_uri`` parameter.
+
+        Parameters
+        ----------
+        tls: boolean
+            Whether TLS is enabled; selects the default protocol when
+            the URI does not name one.
+        username: str or None
+            Username if not supplied as part of the URI
+        password: str or None
+            Password if not supplied as part of the URI
+
+        Returns
+        -------
+        tuple
+            ``(protocol, hostname, username, password)``
+
+        Raises
+        ------
+        ValueError
+            If the URI is empty, if credentials are supplied both in
+            the URI and explicitly, or if they are missing.
+        """
+        uri = self.dremio_uri
+        if not uri:
+            raise ValueError("Dremio URI must be provided.")
+        if "://" in uri:
+            protocol, uri = uri.split("://", 1)
+        else:
+            protocol = "grpc+tls" if tls else "grpc+tcp"
+        if "@" in uri:
+            if username or password:
+                raise ValueError(
+                    "Dremio URI must not include username and password "
+                    "if they were supplied explicitly."
+                )
+            # Split on the last "@" and the first ":" so that passwords
+            # containing either character stay intact.
+            userinfo, hostname = uri.rsplit("@", 1)
+            username, separator, password = userinfo.partition(":")
+            if not separator:
+                raise ValueError(
+                    "Dremio URI must include username and password "
+                    "or they must be provided explicitly."
+                )
+        elif not (username and password):
+            raise ValueError(
+                "Dremio URI must include username and password "
+                "or they must be provided explicitly."
+            )
+        else:
+            hostname = uri
+        return protocol, hostname, username, password
+
+    def _execute_dremio_sql(self, sql_expr, as_pandas=False):
+        """
+        Executes a SQL expression on Dremio via Apache Arrow Flight.
+
+        Parameters
+        ----------
+        sql_expr: str
+            The SQL expression to execute.
+        as_pandas: bool
+            Whether to read the result with ``read_pandas`` rather
+            than ``read_all``.
+        """
+        from pyarrow import flight
+
+        flight_desc = flight.FlightDescriptor.for_command(sql_expr)
+        options = flight.FlightCallOptions(headers=self._headers)
+        flight_info = self._dremio_client.get_flight_info(flight_desc, options)
+        # NOTE: only the first endpoint of the flight is read.
+        reader = self._dremio_client.do_get(flight_info.endpoints[0].ticket, options)
+        return reader.read_pandas() if as_pandas else reader.read_all()
+
+    def _ingest_table(self, table):
+        """
+        Ingests a table from Dremio to DuckDB.
+        """
+        from pyarrow.dataset import dataset as arrow_dataset
+
+        sql_expr = self.get_sql_expr(table, quoted=True)
+        data_table = self._execute_dremio_sql(sql_expr)
+        arrow_ds = arrow_dataset(source=[data_table])
+        self._connection.from_arrow(arrow_ds).to_view(self._unquote(table))
+
+    def _ensure_ingested(self, table):
+        """
+        Ingests an unquoted table name unless DuckDB already lists it.
+        """
+        # Compare against the names themselves: ``in`` on the pandas
+        # Series would test its index labels, never the table names.
+        ingested_tables = set(self._connection.execute("SHOW TABLES").fetchdf()["name"])
+        if table not in ingested_tables:
+            self._ingest_table(table)
+
+    def _quote(self, table):
+        """
+        Convert ABC.DEF to "ABC"."DEF" for Dremio.
+
+        Dotted names of any depth are quoted, e.g. A.B.C becomes
+        "A"."B"."C".
+        """
+        return re.sub(
+            r'\w+(?:\.\w+)+',
+            lambda match: '.'.join(f'"{part}"' for part in match.group(0).split('.')),
+            table,
+        )
+
+    def _unquote(self, table):
+        """
+        Convert "ABC"."DEF" to ABC.DEF for DuckDB.
+
+        Quoted names of any depth are unquoted, e.g. "A"."B"."C"
+        becomes A.B.C.
+        """
+        return re.sub(
+            r'"\w+"(?:\."\w+")+',
+            lambda match: match.group(0).replace('"', ''),
+            table,
+        )
+
+    def get_sql_expr(self, table: str, quoted: bool = False):
+        """
+        Returns the SQL expression for a given table.
+
+        Parameters
+        ----------
+        table: str
+            The name of the table.
+        quoted: bool
+            Whether to quote the table name.
+        """
+        if quoted:
+            table = self._quote(table)
+        else:
+            table = self._unquote(table)
+        sql_expr = super().get_sql_expr(table)
+        return sql_expr
+
+    def get_tables(self):
+        """
+        Returns the configured tables or, if none are configured, the
+        tables Dremio reports, formatted as "database"."table".
+        """
+        if isinstance(self.tables, (dict, list)):
+            return list(self.tables)
+
+        databases = self._execute_dremio_sql("SHOW DATABASES", as_pandas=True)["SCHEMA_NAME"]
+
+        all_tables = []
+        for database in databases:
+            # Slicing keeps an empty name from raising IndexError.
+            if not database[:1].isalpha():
+                continue
+            try:
+                sql_expr = f"SHOW TABLES IN {database}"
+                tables = self._execute_dremio_sql(sql_expr, as_pandas=True)["TABLE_NAME"]
+            except Exception:
+                # Best effort: skip databases that cannot be listed
+                # instead of reusing the previous database's tables.
+                continue
+            all_tables.extend(f'"{database}"."{table}"' for table in tables)
+        return all_tables
+
+    @cached_schema
+    def get_schema(
+        self, table: str | None = None, limit: int | None = None
+    ) -> dict[str, dict[str, Any]] | dict[str, Any]:
+        """
+        Returns the schema of one table, or of all tables if ``table``
+        is None, ingesting from Dremio whatever is not in DuckDB yet.
+        """
+        if table is not None:
+            table = self._unquote(table)
+            self._ensure_ingested(table)
+        elif self.load_schema:
+            # Presumably the parent queries every table for its schema,
+            # so none may be missing from DuckDB.
+            for entry in self.get_tables():
+                self._ensure_ingested(self._unquote(entry))
+        return super().get_schema(table, limit)
+
+    @cached
+    def get(self, table, **query):
+        """
+        Returns the data of a table, ingesting it from Dremio first if
+        DuckDB does not hold it yet.
+        """
+        table = self._unquote(table)
+        self._ensure_ingested(table)
+        return super().get(table, **query)