From ebe4b165473cf08a1d4eebdf952a22ab308db107 Mon Sep 17 00:00:00 2001 From: Alex Carney Date: Fri, 17 Nov 2023 17:24:47 +0000 Subject: [PATCH] wip --- pygls/{client.py => client/__init__.py} | 0 pygls/client/_native.py | 90 +++++++ pygls/handler/_native.py | 311 ++++++++++++++++++++++++ pygls/protocol/next.py | 277 +++++++++++++++++++++ pygls/{server.py => server/__init__.py} | 0 pygls/server/_native.py | 69 ++++++ pygls/server/_wasm.py | 0 tests/servers/rpc.py | 21 ++ tests/test_server.py | 26 ++ 9 files changed, 794 insertions(+) rename pygls/{client.py => client/__init__.py} (100%) create mode 100644 pygls/client/_native.py create mode 100644 pygls/handler/_native.py create mode 100644 pygls/protocol/next.py rename pygls/{server.py => server/__init__.py} (100%) create mode 100644 pygls/server/_native.py create mode 100644 pygls/server/_wasm.py create mode 100644 tests/servers/rpc.py create mode 100644 tests/test_server.py diff --git a/pygls/client.py b/pygls/client/__init__.py similarity index 100% rename from pygls/client.py rename to pygls/client/__init__.py diff --git a/pygls/client/_native.py b/pygls/client/_native.py new file mode 100644 index 00000000..37651aae --- /dev/null +++ b/pygls/client/_native.py @@ -0,0 +1,90 @@ +############################################################################ +# Copyright(c) Open Law Library. All rights reserved. # +# See ThirdPartyNotices.txt in the project root for additional notices. # +# # +# Licensed under the Apache License, Version 2.0 (the "License") # +# you may not use this file except in compliance with the License. # +# You may obtain a copy of the License at # +# # +# http: // www.apache.org/licenses/LICENSE-2.0 # +# # +# Unless required by applicable law or agreed to in writing, software # +# distributed under the License is distributed on an "AS IS" BASIS, # +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # +# See the License for the specific language governing permissions and # +# limitations under the License. # +############################################################################ +import asyncio +from typing import Optional +from typing import Type + +from pygls.handler._native import JsonRPCHandler +from pygls.handler._native import aio_main +from pygls.protocol.next import JsonRPCProtocol + + +class JsonRPCClient(JsonRPCHandler): + """Base JSON-RPC client for "native" runtimes""" + + def __init__( + self, *args, protocol_cls: Type[JsonRPCProtocol] = JsonRPCProtocol, **kwargs + ): + super().__init__(*args, protocol=protocol_cls(), **kwargs) + + self._server: Optional[asyncio.subprocess.Process] = None + + @property + def stopped(self) -> bool: + """Return ``True`` if the client has been stopped.""" + return self._stop_event.is_set() + + async def start_io(self, cmd: str, *args, **kwargs): + """Start the given server and communicate with it over stdio.""" + + self.logger.debug("Starting server process: %s", " ".join([cmd, *args])) + server = await asyncio.create_subprocess_exec( + cmd, + *args, + stdout=asyncio.subprocess.PIPE, + stdin=asyncio.subprocess.PIPE, + stderr=asyncio.subprocess.PIPE, + **kwargs, + ) + + assert server.stdout is not None, "Missing server stdout" + assert server.stdin is not None, "Missing server stdin" + + self._writer = server.stdin + self._tasks["<>"] = asyncio.create_task( + aio_main( + reader=server.stdout, + stop_event=self._stop_event, + message_handler=self, + ) + ) + self._tasks["<>"] = asyncio.create_task(self._server_exit()) + self._server = server + + async def _server_exit(self): + if self._server is not None: + await self._server.wait() + self.logger.debug( + "Server process %s exited with return code: %s", + self._server.pid, + self._server.returncode, + ) + await self.server_exit(self._server) + self._stop_event.set() + + async def server_exit(self, server: asyncio.subprocess.Process): + """Called when the server process exits.""" + + async def stop(self): + self._stop_event.set() + + if self._server is not None and self._server.returncode is None: + self.logger.debug("Terminating server process: %s", self._server.pid) + self._server.terminate() + + if len(self._tasks) > 0: + await asyncio.gather(*self._tasks.values()) diff --git a/pygls/handler/_native.py b/pygls/handler/_native.py new file mode 100644 index 00000000..ed214591 --- /dev/null +++ b/pygls/handler/_native.py @@ -0,0 +1,311 @@ +############################################################################ +# Copyright(c) Open Law Library. All rights reserved. # +# See ThirdPartyNotices.txt in the project root for additional notices. # +# # +# Licensed under the Apache License, Version 2.0 (the "License") # +# you may not use this file except in compliance with the License. # +# You may obtain a copy of the License at # +# # +# http: // www.apache.org/licenses/LICENSE-2.0 # +# # +# Unless required by applicable law or agreed to in writing, software # +# distributed under the License is distributed on an "AS IS" BASIS, # +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # +# See the License for the specific language governing permissions and # +# limitations under the License. # +############################################################################ +import asyncio +import inspect +import logging +import re +import threading +from functools import partial +from typing import Any +from typing import Callable +from typing import Dict +from typing import Optional +from typing import TypeVar +from typing import Union + +import pygls.protocol.next as json_rpc +from pygls.exceptions import JsonRpcException +from pygls.exceptions import JsonRpcMethodNotFound +from pygls.feature_manager import FeatureManager + +CONTENT_LENGTH_PATTERN = re.compile(rb"^Content-Length: (\d+)\r\n$") +MessageHandler = Callable[[bytes], None] +TaskOrFuture = Union[asyncio.Task[None], asyncio.Future[Any]] +T = TypeVar("T") + + +class JsonRPCHandler: + """JSON-RPC message handler for "native" runtimes.""" + + def __init__( + self, + protocol: json_rpc.JsonRPCProtocol, + logger: Optional[logging.Logger] = None, + ) -> None: + self.protocol = protocol + self.logger = logger or logging.getLogger(__name__) + + self._stop_event = threading.Event() + self._include_headers = True + self._writer: Optional[asyncio.StreamWriter] = None + self._features = FeatureManager(self) + self._tasks: Dict[json_rpc.MsgId, TaskOrFuture] = {} + + def __call__(self, data: bytes) -> Any: + try: + message = self.protocol.decode_message(data) + except JsonRpcException as exc: + self._error_handler(exc) + return + + # Generate a message id if required - that way we can reference any task by id. + msg_id: json_rpc.MsgId = getattr(message, "id", None) or self.protocol.make_id() + msg_type = type(message).__name__ + + if isinstance(message, json_rpc.JsonRPCRequestMessage): + coro = self._handle_request(message.id, message.method, message.params) + + elif isinstance(message, json_rpc.JsonRPCNotification): + coro = self._handle_notification(message.method, message.params) + + elif isinstance(message, json_rpc.JsonRPCResultMessage): + coro = self._handle_result(message.id, message.result) + + else: + coro = self._handle_error(message.id, message.error) + + if inspect.iscoroutine(coro): + task = asyncio.create_task(coro, name=f"{msg_type}: {msg_id}") + self._tasks[msg_id] = task + task.add_done_callback(partial(self._finish_task, msg_id)) + + def _finish_task(self, msg_id: json_rpc.MsgId, _t: TaskOrFuture): + """Cleanup a finished task or completed future.""" + if (task := self._tasks.pop(msg_id, None)) is not None: + if isinstance(task, asyncio.Task): + self.logger.debug("Task '%s' finished", task.get_name()) + else: + self.logger.debug("Future '%s' completed", msg_id) + + @property + def writer(self) -> asyncio.StreamWriter: + if self._writer is None: + raise RuntimeError("Unable to send data, writer not available!") + + return self._writer + + def _get_handler(self, method: str): + return self._features.builtin_features.get( + method, self._features.features.get(method, None) + ) + + async def send_error(self, msg_id: json_rpc.MsgId, error: Any): + """Send an error message.""" + message = self.protocol.encode_error(msg_id, error, self._include_headers) + self.writer.write(message) + await self.writer.drain() + + async def send_notification(self, method: str, params: Any): + """Send a notification message.""" + message = self.protocol.encode_notification( + method, params, self._include_headers + ) + self.writer.write(message) + await self.writer.drain() + + async def send_result(self, msg_id: json_rpc.MsgId, result: Any): + """Send a result message.""" + message = self.protocol.encode_result(msg_id, result, self._include_headers) + self.writer.write(message) + await self.writer.drain() + + async def _wait_for_response(self, msg_id: json_rpc.MsgId): + """Wait for a response to the given request id.""" + if msg_id not in self._tasks: + self.logger.error("Unknown message id: %s", msg_id) + return + + result = await self._tasks[msg_id] + + # Now that the response has been received, remove the future. + self._tasks.pop(msg_id) + return result + + async def send_request( + self, method: str, params: Any, msg_id: Optional[json_rpc.MsgId] = None + ): + """Send a request message.""" + + msg_id = msg_id or self.protocol.make_id() + request = self.protocol.encode_request(method, params=params, msg_id=msg_id) + + fut = asyncio.get_running_loop().create_future() + self._tasks[msg_id] = fut + + task_id = f"{msg_id}-response" + task = asyncio.create_task(self._wait_for_response(msg_id), name=task_id) + self._tasks[task_id] = task + task.add_done_callback(partial(self._finish_task, task_id)) + + self.writer.write(request) + await self.writer.drain() + + await task + return task.result() + + async def _handle_request(self, msg_id: json_rpc.MsgId, method: str, params: Any): + """Handle a JSON-RPC request message. + + Parameters + ---------- + msg_id + The message id. + + method + The method name + + params + The request parameters + """ + if (handler := self._get_handler(method)) is None: + error = JsonRpcMethodNotFound.of(method) + await self.send_error(msg_id, error.to_response_error()) + return + + result = handler(params) + if inspect.isawaitable(result): + result = await result + + await self.send_result(msg_id, result) + + async def _handle_notification(self, method: str, params: Any): + """Handle a JSON-RPC notification message. + + Parameters + ---------- + method + The method name + + params + The message parameters + """ + if (handler := self._get_handler(method)) is None: + error = JsonRpcMethodNotFound.of(method) + await self.send_error(msg_id, error.to_response_error()) + return + + result = handler(params) + if inspect.isawaitable(result): + result = await result + + def _handle_result(self, msg_id: json_rpc.MsgId, result: Any): + """Handle a JSON-RPC result message. + + Parameters + ---------- + msg_id + The message id + + result + The result + """ + if msg_id not in self._tasks: + self.logger.error("Received result for unknown message '%s'", msg_id) + return + + self._tasks[msg_id].set_result(result) + + def _handle_error(self, msg_id: json_rpc.MsgId, error: Any): + """Handle a JSON-RPC error message. + + Parameters + ---------- + msg_id + The message id + + error + The error + """ + if msg_id not in self._tasks: + self.logger.error( + "Received error response for unknown message '%s'", msg_id + ) + return + + exc = JsonRpcException.from_error(error) + self._tasks[msg_id].set_exception(exc) + + def error_handler(self, exc: Exception): + """Override to customize error handling""" + self.logger.error("%s", exc, exc_info=True) + + def _error_handler(self, exc: Exception): + try: + res = self.error_handler(exc) + + if inspect.iscoroutine(res): + msg_id = self.protocol.make_id() + task = asyncio.create_task(res, name=f"Error handler for: {exc}") + + self._tasks[msg_id] = task + task.add_done_callback(partial(self._finish_task, msg_id)) + except Exception: + self.logger.error("There was an error handling an error!!", exc_info=True) + + def feature( + self, + feature_name: str, + options: Optional[Any] = None, + ): + """Decorator used to register features. + + Example + ------- + :: + + import logging + from pygls.client import JsonRPCClient + + ls = JsonRPCClient() + + @ls.feature('window/logMessage') + def completions(ls, params): + logging.info("%s", params.message) + """ + return self._features.feature(feature_name, options) + + +async def aio_main( + reader: asyncio.StreamReader, + stop_event: threading.Event, + message_handler: MessageHandler, +): + """Main loop implemented using asyncio.""" + + content_length = 0 + + while not stop_event.is_set(): + # Read a header line + header = await reader.readline() + if not header: + break + + # Extract content length if possible + if not content_length: + match = CONTENT_LENGTH_PATTERN.fullmatch(header) + if match: + content_length = int(match.group(1)) + + # Check if all headers have been read (as indicated by an empty line \r\n) + if content_length and not header.strip(): + # Read body + body = await reader.readexactly(content_length) + if not body: + break + + message_handler(body) + content_length = 0 diff --git a/pygls/protocol/next.py b/pygls/protocol/next.py new file mode 100644 index 00000000..b37678df --- /dev/null +++ b/pygls/protocol/next.py @@ -0,0 +1,277 @@ +############################################################################ +# Copyright(c) Open Law Library. All rights reserved. # +# See ThirdPartyNotices.txt in the project root for additional notices. # +# # +# Licensed under the Apache License, Version 2.0 (the "License") # +# you may not use this file except in compliance with the License. # +# You may obtain a copy of the License at # +# # +# http: // www.apache.org/licenses/LICENSE-2.0 # +# # +# Unless required by applicable law or agreed to in writing, software # +# distributed under the License is distributed on an "AS IS" BASIS, # +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # +# See the License for the specific language governing permissions and # +# limitations under the License. # +############################################################################ +import json +import logging +import typing +from typing import Any +from typing import Callable +from typing import Dict +from typing import Optional +from typing import Protocol +from typing import Type +from typing import Union +from uuid import uuid4 + +import attrs +import cattrs + +from pygls.exceptions import JsonRpcInternalError +from pygls.exceptions import JsonRpcInvalidParams + +MsgId = Union[str, int] +ConverterFactory = Callable[[], cattrs.Converter] +IdFactory = Callable[[], MsgId] + + +def uuids(): + return str(uuid4()) + + +@typing.runtime_checkable +class JsonRPCNotification(Protocol): + """Represents the generic shape of a json rpc notification message.""" + + method: str + jsonrpc: str + params: Any + + +@typing.runtime_checkable +class JsonRPCRequestMessage(Protocol): + """Represents the generic shape of a json rpc request message.""" + + id: MsgId + method: str + jsonrpc: str + params: Any + + +@typing.runtime_checkable +class JsonRPCResultMessage(Protocol): + """Represents the generic shape of a json rpc result message.""" + + id: MsgId + jsonrpc: str + result: Any + + +@typing.runtime_checkable +class JsonRPCErrorMessage(Protocol): + """Represents the generic shape of a json rpc error message.""" + + id: MsgId + jsonrpc: str + error: Any + + +JsonRPCMessage = Union[ + JsonRPCErrorMessage, + JsonRPCNotification, + JsonRPCRequestMessage, + JsonRPCResultMessage, +] + + +def basic_converter(): + converter = cattrs.Converter() + converter.register_structure_hook(Union[str, int], lambda o, t: o) + + return converter + + +class JsonRPCProtocol: + """JSON-RPC implementation.""" + + CONTENT_TYPE = "application/vscode-jsonrpc" + + VERSION = "2.0" + + def __init__( + self, + converter_factory: ConverterFactory = basic_converter, + encoding: str = "utf-8", + id_factory: IdFactory = uuids, + logger: Optional[logging.Logger] = None, + ): + self.converter = converter_factory() + self.id_factory = id_factory + self.encoding = encoding + self.logger = logger or logging.getLogger(__name__) + + self._result_types: Dict[ + Union[int, str], Optional[Type[JsonRPCResultMessage]] + ] = {} + + def get_notification_type(self, method: str) -> Optional[Type[JsonRPCNotification]]: + """Return the type definition of the notification associated with the given + method.""" + return None + + def get_request_type(self, method: str) -> Optional[Type[JsonRPCRequestMessage]]: + """Return the type definition of the result associated with the given method.""" + return None + + def get_result_type(self, method: str) -> Optional[Type[JsonRPCResultMessage]]: + """Return the type definition of the result associated with the given method.""" + return None + + def make_id(self) -> MsgId: + """Return a new message id.""" + # TODO: Include logic to make sure this is unique? + return self.id_factory() + + def encode_request( + self, + method: str, + params: Optional[Any] = None, + msg_id: Optional[MsgId] = None, + include_headers: bool = True, + ) -> bytes: + """Construct a JSON-RPC request to send.""" + + msg_id = msg_id or self.make_id() + + request_type = self.get_request_type(method) or _Request + request = request_type( + id=msg_id, method=method, params=params, jsonrpc=self.VERSION + ) + + # Lookup what the expected result type is for this message + self._result_types[msg_id] = self.get_result_type(method) + + data = self.converter.unstructure(request) + return self._encode_message(data, include_headers=include_headers) + + def encode_notification( + self, method: str, params: Optional[Any] = None, include_headers: bool = True + ) -> bytes: + """Construct a JSON-RPC notification to send.""" + + notification_type = self.get_notification_type(method) or _Notification + notification = notification_type( + method=method, params=params, jsonrpc=self.VERSION + ) + + data = self.converter.unstructure(notification) + return self._encode_message(data, include_headers=include_headers) + + def encode_error( + self, msg_id: MsgId, error: Optional[Any] = None, include_headers: bool = True + ) -> bytes: + """Construct a JSON-RPC error to send.""" + + response = _Error(id=msg_id, error=error, jsonrpc=self.VERSION) + data = self.converter.unstructure(response) + + return self._encode_message(data, include_headers=include_headers) + + def encode_result( + self, msg_id: MsgId, result: Optional[Any] = None, include_headers: bool = True + ) -> bytes: + """Construct a JSON-RPC result to send.""" + + response_type = self._result_types.pop(msg_id, None) or _Result + response = response_type(id=msg_id, result=result, jsonrpc=self.VERSION) + + data = self.converter.unstructure(response) + return self._encode_message(data, include_headers=include_headers) + + def _encode_message(self, data: Any, include_headers: bool) -> bytes: + """Encode the given data as bytes""" + body = json.dumps(data) + self.logger.debug("%s", body) + + if include_headers: + header = ( + f"Content-Length: {len(body)}\r\n" + f"Content-Type: {self.CONTENT_TYPE}; charset={self.encoding}\r\n\r\n" + ).encode(self.encoding) + return header + body.encode(self.encoding) + + return body.encode(self.encoding) + + def decode_message(self, data: bytes) -> JsonRPCMessage: + body = data.decode(self.encoding) + self.logger.debug(body) + return json.loads(body, object_hook=self._deserialize_message) + + def _deserialize_message(self, data: Any) -> Any: + """Function used to deserialize data recevied from the client.""" + + if "jsonrpc" not in data: + return data + + try: + if "id" in data: + if "error" in data: + return self.converter.structure(data, _Error) + elif "method" in data: + request_type = self.get_request_type(data["method"]) or _Request + return self.converter.structure(data, request_type) + else: + response_type = self._result_types.pop(data["id"]) or _Result + return self.converter.structure(data, response_type) + + else: + method = data.get("method", "") + notification_type = self.get_notification_type(method) or _Notification + return self.converter.structure(data, notification_type) + + except cattrs.ClassValidationError as exc: + self.logger.error("Unable to deserialize message\n%s", exc_info=True) + raise JsonRpcInvalidParams() from exc + + except Exception as exc: + self.logger.error("Unable to deserialize message\n%s", exc_info=True) + raise JsonRpcInternalError() from exc + + +@attrs.define +class _Notification: + """Fallback type representing a generic json rpc notification message.""" + + method: str + jsonrpc: str + params: Any + + +@attrs.define +class _Request: + """Fallback type representing a generic json rpc request message.""" + + id: MsgId + method: str + jsonrpc: str + params: Any + + +@attrs.define +class _Result: + """Fallback type representing a generic json rpc result message.""" + + id: MsgId + jsonrpc: str + result: Any + + +@attrs.define +class _Error: + """Fallback type representing a generic json rpc error message.""" + + id: MsgId + jsonrpc: str + error: Any diff --git a/pygls/server.py b/pygls/server/__init__.py similarity index 100% rename from pygls/server.py rename to pygls/server/__init__.py diff --git a/pygls/server/_native.py b/pygls/server/_native.py new file mode 100644 index 00000000..2d9b0801 --- /dev/null +++ b/pygls/server/_native.py @@ -0,0 +1,69 @@ +############################################################################ +# Copyright(c) Open Law Library. All rights reserved. # +# See ThirdPartyNotices.txt in the project root for additional notices. # +# # +# Licensed under the Apache License, Version 2.0 (the "License") # +# you may not use this file except in compliance with the License. # +# You may obtain a copy of the License at # +# # +# http: // www.apache.org/licenses/LICENSE-2.0 # +# # +# Unless required by applicable law or agreed to in writing, software # +# distributed under the License is distributed on an "AS IS" BASIS, # +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # +# See the License for the specific language governing permissions and # +# limitations under the License. # +############################################################################ +import asyncio +import sys +from typing import BinaryIO +from typing import Optional +from typing import Tuple +from typing import Type + +from pygls.handler._native import JsonRPCHandler +from pygls.handler._native import aio_main +from pygls.protocol.next import JsonRPCProtocol + + +class JsonRPCServer(JsonRPCHandler): + """Base JSON-RPC server for "native" runtimes.""" + + def __init__( + self, *args, protocol_cls: Type[JsonRPCProtocol] = JsonRPCProtocol, **kwargs + ): + super().__init__(*args, protocol=protocol_cls(), **kwargs) + + async def start_io( + self, stdin: Optional[BinaryIO] = None, stdout: Optional[BinaryIO] = None + ): + stdin = stdin or sys.stdin.buffer + stdout = stdout or sys.stdout.buffer + + reader, writer = await get_sdtio_streams(stdin, stdout) + self._writer = writer + + await aio_main( + reader=reader, + stop_event=self._stop_event, + message_handler=self, + ) + + +async def get_sdtio_streams( + stdin: BinaryIO, stdout: BinaryIO +) -> Tuple[asyncio.StreamReader, asyncio.StreamWriter]: + """Get async stdio streams to use with the server.""" + + # TODO: This only works on linux! + loop = asyncio.get_running_loop() + + reader = asyncio.StreamReader() + read_protocol = asyncio.StreamReaderProtocol(reader) + await loop.connect_read_pipe(lambda: read_protocol, stdin) + + write_transport, write_protocol = await loop.connect_write_pipe( + asyncio.streams.FlowControlMixin, stdout + ) + writer = asyncio.StreamWriter(write_transport, write_protocol, reader, loop) + return reader, writer diff --git a/pygls/server/_wasm.py b/pygls/server/_wasm.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/servers/rpc.py b/tests/servers/rpc.py new file mode 100644 index 00000000..d1235f3b --- /dev/null +++ b/tests/servers/rpc.py @@ -0,0 +1,21 @@ +"""A generic JSON-RPC server""" +import asyncio +import logging +from typing import Dict + +from pygls.server._native import JsonRPCServer + +server = JsonRPCServer() + + +@server.feature("math/add") +def add(params: Dict[str, float]): + a = params["a"] + b = params["b"] + + return dict(sum=a + b) + + +if __name__ == "__main__": + logging.basicConfig(level=logging.DEBUG, filename="server.log", filemode="w") + asyncio.run(server.start_io()) diff --git a/tests/test_server.py b/tests/test_server.py new file mode 100644 index 00000000..28789eb1 --- /dev/null +++ b/tests/test_server.py @@ -0,0 +1,26 @@ +import pathlib +import sys + +import pytest + +from pygls.client._native import JsonRPCClient + +SERVERS = pathlib.Path(__file__).parent / "servers" + + +@pytest.fixture() +async def client(): + client_ = JsonRPCClient() + await client_.start_io(sys.executable, str(SERVERS / "rpc.py")) + + yield client_ + + await client_.stop() + + +async def test_invalid_method_async(client: JsonRPCClient): + """Ensure that the server responds with an appropriate error when the method is not + implemented.""" + + result = await client.send_request("math/add", dict(a=1, b=4)) + assert result["sum"] == 5