From 014a411798e7b568be3a11bbfb5a4afadff7f8b4 Mon Sep 17 00:00:00 2001 From: szeyu Date: Tue, 3 Sep 2024 14:02:20 +0800 Subject: [PATCH 1/9] update npu models and engine setup --- README.md | 8 + docs/model/npu_models.md | 15 ++ requirements-npu.txt | 3 + setup.py | 9 + src/embeddedllm/backend/npu_engine.py | 268 +++++++++++++++++++++++++ src/embeddedllm/engine.py | 12 +- src/embeddedllm/entrypoints/modelui.py | 53 ++++- 7 files changed, 366 insertions(+), 2 deletions(-) create mode 100644 docs/model/npu_models.md create mode 100644 requirements-npu.txt create mode 100644 src/embeddedllm/backend/npu_engine.py diff --git a/README.md b/README.md index 526dcbb..5d371d8 100644 --- a/README.md +++ b/README.md @@ -35,6 +35,7 @@ Run local LLMs on iGPU, APU and CPU (AMD , Intel, and Qualcomm (Coming Soon)). E * Onnxruntime CPU Models [Link](./docs/model/onnxruntime_cpu_models.md) * Ipex-LLM Models [Link](./docs/model/ipex_models.md) * OpenVINO-LLM Models [Link](./docs/model/openvino_models.md) + * NPU-LLM Models [Link](./docs/model/npu_models.md) ## Getting Started @@ -56,12 +57,14 @@ Run local LLMs on iGPU, APU and CPU (AMD , Intel, and Qualcomm (Coming Soon)). E - **CUDA:** `$env:ELLM_TARGET_DEVICE='cuda'; pip install -e .[cuda]` - **IPEX:** `$env:ELLM_TARGET_DEVICE='ipex'; python setup.py develop` - **OpenVINO:** `$env:ELLM_TARGET_DEVICE='openvino'; pip install -e .[openvino]` + - **NPU:** `$env:ELLM_TARGET_DEVICE='npu'; pip install -e .[npu]` - **With Web UI**: - **DirectML:** `$env:ELLM_TARGET_DEVICE='directml'; pip install -e .[directml,webui]` - **CPU:** `$env:ELLM_TARGET_DEVICE='cpu'; pip install -e .[cpu,webui]` - **CUDA:** `$env:ELLM_TARGET_DEVICE='cuda'; pip install -e .[cuda,webui]` - **IPEX:** `$env:ELLM_TARGET_DEVICE='ipex'; python setup.py develop; pip install -r requirements-webui.txt` - **OpenVINO:** `$env:ELLM_TARGET_DEVICE='openvino'; pip install -e .[openvino,webui]` + - **NPU:** `$env:ELLM_TARGET_DEVICE='npu'; pip install -e .[npu,webui]` - **Linux** @@ -77,12 +80,14 @@ Run local LLMs on iGPU, APU and CPU (AMD , Intel, and Qualcomm (Coming Soon)). 
E - **CUDA:** `ELLM_TARGET_DEVICE='cuda' pip install -e .[cuda]` - **IPEX:** `ELLM_TARGET_DEVICE='ipex' python setup.py develop` - **OpenVINO:** `ELLM_TARGET_DEVICE='openvino' pip install -e .[openvino]` + - **NPU:** `ELLM_TARGET_DEVICE='npu' pip install -e .[npu]` - **With Web UI**: - **DirectML:** `ELLM_TARGET_DEVICE='directml' pip install -e .[directml,webui]` - **CPU:** `ELLM_TARGET_DEVICE='cpu' pip install -e .[cpu,webui]` - **CUDA:** `ELLM_TARGET_DEVICE='cuda' pip install -e .[cuda,webui]` - **IPEX:** `ELLM_TARGET_DEVICE='ipex' python setup.py develop; pip install -r requirements-webui.txt` - **OpenVINO:** `ELLM_TARGET_DEVICE='openvino' pip install -e .[openvino,webui]` + - **NPU:** `ELLM_TARGET_DEVICE='npu' pip install -e .[npu,webui]` ### Launch OpenAI API Compatible Server @@ -161,6 +166,9 @@ _Powershell/Terminal Usage (Use it like `ellm_server`)_: # OpenVINO .\ellm_api_server.exe --model_path '.\meta-llama_Meta-Llama-3.1-8B-Instruct\' --backend 'openvino' --device 'gpu' --port 5555 --served_model_name 'meta-llama_Meta/Llama-3.1-8B-Instruct' + +# NPU +.\ellm_api_server.exe --model_path 'microsoft/Phi-3-mini-4k-instruct' --backend 'npu' --device 'npu' --port 5555 --served_model_name 'microsoft/Phi-3-mini-4k-instruct' ``` ## Acknowledgements diff --git a/docs/model/npu_models.md b/docs/model/npu_models.md new file mode 100644 index 0000000..c1d2b06 --- /dev/null +++ b/docs/model/npu_models.md @@ -0,0 +1,15 @@ +# Model Powered by NPU-LLM + +## Verified Models +Verified models can be found from EmbeddedLLM NPU-LLM model collections +* EmbeddedLLM NPU-LLM Model collections: [link](https://huggingface.co/collections/EmbeddedLLM/npu-llm-66d692817e6c9509bb8ead58) + +| Model | Model Link | +| --- | --- | +| Phi-3-mini-4k-instruct | [link](https://huggingface.co/microsoft/Phi-3-mini-4k-instruct) | +| Phi-3-mini-128k-instruct | [link](https://huggingface.co/microsoft/Phi-3-mini-128k-instruct) | +| Phi-3-medium-4k-instruct | [link](https://huggingface.co/microsoft/Phi-3-medium-4k-instruct) | +| Phi-3-medium-128k-instruct | [link](https://huggingface.co/microsoft/Phi-3-medium-128k-instruct) | + +## Contribution +We welcome contributions to the verified model list. 
\ No newline at end of file diff --git a/requirements-npu.txt b/requirements-npu.txt new file mode 100644 index 0000000..dbcb8cf --- /dev/null +++ b/requirements-npu.txt @@ -0,0 +1,3 @@ +intel-npu-acceleration-library +torch>=2.4 +transformers>=4.42 \ No newline at end of file diff --git a/setup.py b/setup.py index 009aad2..a829c76 100644 --- a/setup.py +++ b/setup.py @@ -54,6 +54,10 @@ def _is_openvino() -> bool: return ELLM_TARGET_DEVICE == "openvino" +def _is_npu() -> bool: + return ELLM_TARGET_DEVICE == "npu" + + class ELLMInstallCommand(install): def run(self): install.run(self) @@ -186,6 +190,8 @@ def get_requirements() -> List[str]: requirements = _read_requirements("requirements-ipex.txt") elif _is_openvino(): requirements = _read_requirements("requirements-openvino.txt") + elif _is_npu(): + requirements = _read_requirements("requirements-npu.txt") else: raise ValueError("Unsupported platform, please use CUDA, ROCm, Neuron, or CPU.") return requirements @@ -204,6 +210,8 @@ def get_ellm_version() -> str: version += "+ipex" elif _is_openvino(): version += "+openvino" + elif _is_npu(): + version += "+npu" else: raise RuntimeError("Unknown runtime environment") @@ -256,6 +264,7 @@ def get_ellm_version() -> str: "cuda": ["onnxruntime-genai-cuda==0.3.0rc2"], "ipex": [], "openvino": [], + "npu": [], }, dependency_links=dependency_links, entry_points={ diff --git a/src/embeddedllm/backend/npu_engine.py b/src/embeddedllm/backend/npu_engine.py new file mode 100644 index 0000000..e00250c --- /dev/null +++ b/src/embeddedllm/backend/npu_engine.py @@ -0,0 +1,268 @@ +import contextlib +import time +from pathlib import Path +from tempfile import TemporaryDirectory +from typing import AsyncIterator, List, Optional + +from loguru import logger +from PIL import Image +from transformers import ( + AutoConfig, + PreTrainedTokenizer, + PreTrainedTokenizerFast, + TextIteratorStreamer, +) + +from threading import Thread + +import intel_npu_acceleration_library as npu_lib + +from embeddedllm.inputs import PromptInputs +from embeddedllm.protocol import CompletionOutput, RequestOutput +from embeddedllm.sampling_params import SamplingParams +from embeddedllm.backend.base_engine import BaseLLMEngine, _get_and_verify_max_len + +RECORD_TIMING = True + + +class NPUEngine(BaseLLMEngine): + def _init_(self, model_path: str, vision: bool, device: str = "npu"): + self.model_path = model_path + self.model_config: AutoConfig = AutoConfig.from_pretrained( + self.model_path, trust_remote_code=True + ) + self.device = device + + # model_config is to find out the max length of the model + self.max_model_len = _get_and_verify_max_len( + hf_config=self.model_config, + max_model_len=None, + disable_sliding_window=False, + sliding_window_len=self.get_hf_config_sliding_window(), + ) + + logger.info("Model Context Length: " + str(self.max_model_len)) + + try: + logger.info("Attempt to load fast tokenizer") + self.tokenizer = PreTrainedTokenizerFast.from_pretrained(self.model_path) + except Exception: + logger.info("Attempt to load slower tokenizer") + self.tokenizer = PreTrainedTokenizer.from_pretrained(self.model_path) + + self.model = npu_lib.NPUModelForCausalLM.from_pretrained( + self.model_path, + torch_dtype="auto", + dtype=npu_lib.int4, + trust_remote_code=True, + export=False + ) + + logger.info("Model loaded") + self.tokenizer_stream = TextIteratorStreamer( + self.tokenizer, skip_prompt=True, skip_special_tokens=True + ) + logger.info("Tokenizer created") + + self.vision = vision + + # if self.vision: + # 
self.onnx_processor = self.model.create_multimodal_processor() + # self.processor = AutoImageProcessor.from_pretrained( + # self.model_path, trust_remote_code=True + # ) + # print(dir(self.processor)) + + async def generate_vision( + self, + inputs: PromptInputs, + sampling_params: SamplingParams, + request_id: str, + stream: bool = True, + ) -> AsyncIterator[RequestOutput]: + raise NotImplementedError(f"generate_vision yet to be implemented.") + + async def generate( + self, + inputs: PromptInputs, + sampling_params: SamplingParams, + request_id: str, + stream: bool = True, + ) -> AsyncIterator[RequestOutput]: + """Generate outputs for a request. + + Generate outputs for a request. This method is a coroutine. It adds the + request into the waiting queue of the LLMEngine and streams the outputs + from the LLMEngine to the caller. + + """ + + prompt_text = inputs["prompt"] + input_token_length = None + input_tokens = None # for text only use case + # logger.debug("inputs: " + prompt_text) + + input_tokens = self.tokenizer.encode(prompt_text, return_tensors="pt") + # logger.debug(f"input_tokens: {input_tokens}") + input_token_length = len(input_tokens[0]) + + max_tokens = sampling_params.max_tokens + + assert input_token_length is not None + + if input_token_length + max_tokens > self.max_model_len: + raise ValueError("Exceed Context Length") + + generation_options = { + name: getattr(sampling_params, name) + for name in [ + "do_sample", + # "max_length", + "max_new_tokens", + "min_length", + "top_p", + "top_k", + "temperature", + "repetition_penalty", + ] + if hasattr(sampling_params, name) + } + generation_options["max_length"] = self.max_model_len + generation_options["input_ids"] = input_tokens.clone() + # generation_options["input_ids"] = input_tokens.clone().to(self.device) + generation_options["max_new_tokens"] = max_tokens + print(generation_options) + + token_list: List[int] = [] + output_text: str = "" + if stream: + generation_options["streamer"] = self.tokenizer_stream + if RECORD_TIMING: + started_timestamp = time.time() + first_token_timestamp = 0 + first = True + new_tokens = [] + try: + thread = Thread(target=self.model.generate, kwargs=generation_options) + started_timestamp = time.time() + first_token_timestamp = None + thread.start() + output_text = "" + first = True + for new_text in self.tokenizer_stream: + if new_text == "": + continue + if RECORD_TIMING: + if first: + first_token_timestamp = time.time() + first = False + # logger.debug(f"new text: {new_text}") + output_text += new_text + token_list = self.tokenizer.encode(output_text, return_tensors="pt") + + output = RequestOutput( + request_id=request_id, + prompt=prompt_text, + prompt_token_ids=input_tokens[0], + finished=False, + outputs=[ + CompletionOutput( + index=0, + text=output_text, + token_ids=token_list[0], + cumulative_logprob=-1.0, + ) + ], + ) + yield output + # logits = generator.get_output("logits") + # print(logits) + if RECORD_TIMING: + new_tokens = token_list[0] + + yield RequestOutput( + request_id=request_id, + prompt=prompt_text, + prompt_token_ids=input_tokens[0], + finished=True, + outputs=[ + CompletionOutput( + index=0, + text=output_text, + token_ids=token_list[0], + cumulative_logprob=-1.0, + finish_reason="stop", + ) + ], + ) + if RECORD_TIMING: + prompt_time = first_token_timestamp - started_timestamp + run_time = time.time() - first_token_timestamp + logger.info( + f"Prompt length: {len(input_tokens[0])}, New tokens: {len(new_tokens)}, Time to first: {(prompt_time):.2f}s, Prompt tokens 
per second: {len(input_tokens[0])/prompt_time:.2f} tps, New tokens per second: {len(new_tokens)/run_time:.2f} tps"
+                    )
+
+            except Exception as e:
+                logger.error(str(e))
+
+                error_output = RequestOutput(
+                    prompt=prompt_text,
+                    prompt_token_ids=input_tokens[0],
+                    finished=True,
+                    request_id=request_id,
+                    outputs=[
+                        CompletionOutput(
+                            index=0,
+                            text=output_text,
+                            token_ids=token_list[0] if len(token_list) else [],
+                            cumulative_logprob=-1.0,
+                            finish_reason="error",
+                            stop_reason=str(e),
+                        )
+                    ],
+                )
+                yield error_output
+        else:
+            try:
+                token_list = self.model.generate(**generation_options)[0]
+
+                output_text = self.tokenizer.decode(
+                    token_list[input_token_length:], skip_special_tokens=True
+                )
+
+                yield RequestOutput(
+                    request_id=request_id,
+                    prompt=prompt_text,
+                    prompt_token_ids=input_tokens[0],
+                    finished=True,
+                    outputs=[
+                        CompletionOutput(
+                            index=0,
+                            text=output_text,
+                            token_ids=token_list,
+                            cumulative_logprob=-1.0,
+                            finish_reason="stop",
+                        )
+                    ],
+                )
+
+            except Exception as e:
+                logger.error(str(e))
+
+                error_output = RequestOutput(
+                    prompt=prompt_text,
+                    prompt_token_ids=input_tokens[0],
+                    finished=True,
+                    request_id=request_id,
+                    outputs=[
+                        CompletionOutput(
+                            index=0,
+                            text=output_text,
+                            token_ids=token_list,
+                            cumulative_logprob=-1.0,
+                            finish_reason="error",
+                            stop_reason=str(e),
+                        )
+                    ],
+                )
+                yield error_output
\ No newline at end of file
diff --git a/src/embeddedllm/engine.py b/src/embeddedllm/engine.py
index e2c5a9d..d3933e9 100644
--- a/src/embeddedllm/engine.py
+++ b/src/embeddedllm/engine.py
@@ -56,6 +56,16 @@ def __init__(self, model_path: str, vision: bool, device: str = "xpu", backend:
             self.engine = OnnxruntimeEngine(self.model_path, self.vision, self.device)
             logger.info(f"Initializing onnxruntime backend ({backend.upper()}): OnnxruntimeEngine")
 
+
+        elif self.backend == "npu":
+            assert self.device == "npu", f"To run npu backend, device must be npu."
+            processor = get_processor_type()
+            assert processor == "Intel", f"Only support intel NPU"
+            from embeddedllm.backend.npu_engine import NPUEngine
+
+            self.engine = NPUEngine(self.model_path, self.vision, self.device)
+            logger.info(f"Initializing npu backend (NPU): NPUEngine")
+
         elif self.backend == "cpu":
             assert self.device == "cpu", f"To run `cpu` backend, `device` must be `cpu`."
             processor = get_processor_type()
@@ -80,7 +90,7 @@ def __init__(self, model_path: str, vision: bool, device: str = "xpu", backend:
 
         else:
             raise ValueError(
-                f"EmbeddedLLMEngine only supports `cpu`, `ipex`, `cuda`, `openvino` and `directml`."
+                f"EmbeddedLLMEngine only supports `cpu`, `npu`, `ipex`, `cuda`, `openvino` and `directml`."
) self.tokenizer = self.engine.tokenizer diff --git a/src/embeddedllm/entrypoints/modelui.py b/src/embeddedllm/entrypoints/modelui.py index 9c82355..fb1922a 100644 --- a/src/embeddedllm/entrypoints/modelui.py +++ b/src/embeddedllm/entrypoints/modelui.py @@ -20,7 +20,7 @@ def get_embeddedllm_backend(): version = importlib.metadata.version("embeddedllm") # Use regex to extract the backend - match = re.search(r"\+(directml|cpu|cuda|ipex|openvino)$", version) + match = re.search(r"\+(directml|npu|cpu|cuda|ipex|openvino)$", version) if match: backend = match.group(1) @@ -260,6 +260,41 @@ class ModelCard(BaseModel): ), } +npu_model_dict_list = { + "microsoft/Phi-3-mini-4k-instruct": ModelCard( + hf_url="https://huggingface.co/microsoft/Phi-3-mini-4k-instruct/tree/main/", + repo_id="microsoft/Phi-3-mini-4k-instruct", + model_name="Phi-3-mini-4k-instruct", + subfolder=".", + repo_type="model", + context_length=4096, + ), + "microsoft/Phi-3-mini-128k-instruct": ModelCard( + hf_url="https://huggingface.co/microsoft/Phi-3-mini-128k-instruct/tree/main", + repo_id="microsoft/Phi-3-mini-128k-instruct", + model_name="Phi-3-mini-128k-instruct", + subfolder=".", + repo_type="model", + context_length=131072, + ), + "microsoft/Phi-3-medium-4k-instruct": ModelCard( + hf_url="https://huggingface.co/microsoft/Phi-3-medium-4k-instruct/tree/main", + repo_id="microsoft/Phi-3-medium-4k-instruct", + model_name="Phi-3-medium-4k-instruct", + subfolder=".", + repo_type="model", + context_length=4096, + ), + "microsoft/Phi-3-medium-128k-instruct": ModelCard( + hf_url="https://huggingface.co/microsoft/Phi-3-medium-128k-instruct/tree/main", + repo_id="microsoft/Phi-3-medium-128k-instruct", + model_name="Phi-3-medium-128k-instruct", + subfolder=".", + repo_type="model", + context_length=131072, + ), +} + ipex_model_dict_list = { "microsoft/Phi-3-mini-4k-instruct": ModelCard( hf_url="https://huggingface.co/microsoft/Phi-3-mini-4k-instruct/tree/main/", @@ -507,6 +542,11 @@ def compute_memory_size(repo_id, path_in_repo, repo_type: str = "model"): repo_id=v.repo_id, path_in_repo=v.subfolder, repo_type=v.repo_type ) +for k, v in npu_model_dict_list.items(): + v.size = compute_memory_size( + repo_id=v.repo_id, path_in_repo=v.subfolder, repo_type=v.repo_type + ) + for k, v in ipex_model_dict_list.items(): v.size = compute_memory_size( repo_id=v.repo_id, path_in_repo=v.subfolder, repo_type=v.repo_type @@ -603,6 +643,9 @@ def update_model_list(engine_type): if engine_type == "DirectML": models = sorted(list(dml_model_dict_list.keys())) models_pandas = convert_to_dataframe(dml_model_dict_list) + elif backend == "npu": + models = sorted(list(npu_model_dict_list.keys())) + models_pandas = convert_to_dataframe(npu_model_dict_list) elif backend == "ipex": models = sorted(list(ipex_model_dict_list.keys())) models_pandas = convert_to_dataframe(ipex_model_dict_list) @@ -631,6 +674,8 @@ def deploy_model(engine_type, model_name, port_number): if engine_type == "DirectML": llm_model_card = dml_model_dict_list[model_name] + elif backend == "npu": + llm_model_card = npu_model_dict_list[model_name] elif backend == "ipex": llm_model_card = ipex_model_dict_list[model_name] elif backend == "openvino": @@ -654,6 +699,8 @@ def deploy_model(engine_type, model_name, port_number): model_path = llm_model_card.repo_id print("Model path:", model_path) + if engine_type == "NPU": + device = "npu" if engine_type == "Ipex": device = "xpu" elif engine_type == "OpenVino": @@ -718,6 +765,8 @@ def download_model(engine_type, model_name): if engine_type == 
"DirectML": llm_model_card = dml_model_dict_list[model_name] + elif backend == "npu": + llm_model_card = npu_model_dict_list[model_name] elif backend == "ipex": llm_model_card = ipex_model_dict_list[model_name] elif backend == "openvino": @@ -771,6 +820,8 @@ def main(): if backend == "directml": default_value = "DirectML" + elif backend == "npu": + default_value = "NPU" elif backend == "ipex": default_value = "Ipex" elif backend == "openvino": From 3e92ce5fadb2c8530cdadf4623a66a0963695dbd Mon Sep 17 00:00:00 2001 From: Sze Yu Sim <34510821+szeyu@users.noreply.github.com> Date: Wed, 4 Sep 2024 13:24:39 +0800 Subject: [PATCH 2/9] Update README.md --- README.md | 3 +++ 1 file changed, 3 insertions(+) diff --git a/README.md b/README.md index 5d371d8..d175971 100644 --- a/README.md +++ b/README.md @@ -147,6 +147,9 @@ It is an interface that allows you to download and deploy OpenAI API compatible # OpenVINO ellm_server --model_path '.\meta-llama_Meta-Llama-3.1-8B-Instruct\' --backend 'openvino' --device 'gpu' --port 5555 --served_model_name 'meta-llama_Meta/Llama-3.1-8B-Instruct' + + # NPU + ellm_server --model_path 'microsoft/Phi-3-mini-4k-instruct' --backend 'openvino' --device 'gpu' --port 5555 --served_model_name 'microsoft/Phi-3-mini-4k-instruct' ``` ## Prebuilt OpenAI API Compatible Windows Executable (Alpha) From 736ea8537feabcb35249d176519d3dc5cecbc730 Mon Sep 17 00:00:00 2001 From: Sze Yu Sim <34510821+szeyu@users.noreply.github.com> Date: Wed, 4 Sep 2024 13:25:06 +0800 Subject: [PATCH 3/9] Update README.md --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index d175971..6153509 100644 --- a/README.md +++ b/README.md @@ -149,7 +149,7 @@ It is an interface that allows you to download and deploy OpenAI API compatible ellm_server --model_path '.\meta-llama_Meta-Llama-3.1-8B-Instruct\' --backend 'openvino' --device 'gpu' --port 5555 --served_model_name 'meta-llama_Meta/Llama-3.1-8B-Instruct' # NPU - ellm_server --model_path 'microsoft/Phi-3-mini-4k-instruct' --backend 'openvino' --device 'gpu' --port 5555 --served_model_name 'microsoft/Phi-3-mini-4k-instruct' + ellm_server --model_path 'microsoft/Phi-3-mini-4k-instruct' --backend 'npu' --device 'npu' --port 5555 --served_model_name 'microsoft/Phi-3-mini-4k-instruct' ``` ## Prebuilt OpenAI API Compatible Windows Executable (Alpha) From 0bacaa75241650c6805bcce77001a5270b59701e Mon Sep 17 00:00:00 2001 From: szeyu Date: Wed, 4 Sep 2024 14:54:07 +0800 Subject: [PATCH 4/9] fix the typo of __init__ --- src/embeddedllm/backend/npu_engine.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/embeddedllm/backend/npu_engine.py b/src/embeddedllm/backend/npu_engine.py index e00250c..c245e43 100644 --- a/src/embeddedllm/backend/npu_engine.py +++ b/src/embeddedllm/backend/npu_engine.py @@ -26,7 +26,7 @@ class NPUEngine(BaseLLMEngine): - def _init_(self, model_path: str, vision: bool, device: str = "npu"): + def __init__(self, model_path: str, vision: bool, device: str = "npu"): self.model_path = model_path self.model_config: AutoConfig = AutoConfig.from_pretrained( self.model_path, trust_remote_code=True From 2d730a3e5d9e15dc2f00ad7f527bd40b3f7e254c Mon Sep 17 00:00:00 2001 From: Sze Yu Sim <34510821+szeyu@users.noreply.github.com> Date: Wed, 4 Sep 2024 15:43:59 +0800 Subject: [PATCH 5/9] Update modelui.py fix the logic error of if else --- src/embeddedllm/entrypoints/modelui.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/embeddedllm/entrypoints/modelui.py 
b/src/embeddedllm/entrypoints/modelui.py index fb1922a..81cb681 100644 --- a/src/embeddedllm/entrypoints/modelui.py +++ b/src/embeddedllm/entrypoints/modelui.py @@ -701,7 +701,7 @@ def deploy_model(engine_type, model_name, port_number): if engine_type == "NPU": device = "npu" - if engine_type == "Ipex": + elif engine_type == "Ipex": device = "xpu" elif engine_type == "OpenVino": device = "gpu" From 5504f889075ad1086bda220230da4a5263391963 Mon Sep 17 00:00:00 2001 From: Sze Yu Sim <34510821+szeyu@users.noreply.github.com> Date: Thu, 26 Sep 2024 12:59:20 +0800 Subject: [PATCH 6/9] [BUG FIXED] Update gradio version in requirements-webui.txt --- requirements-webui.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements-webui.txt b/requirements-webui.txt index 8a51e1c..5fa09b7 100644 --- a/requirements-webui.txt +++ b/requirements-webui.txt @@ -1 +1 @@ -gradio~=4.36.1 \ No newline at end of file +gradio~=4.43.0 From abfab05c4eb0bfe5c32a52b7d06dca6f7b7ed2ea Mon Sep 17 00:00:00 2001 From: szeyu Date: Thu, 26 Sep 2024 14:48:58 +0800 Subject: [PATCH 7/9] update gitignore --- .gitignore | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/.gitignore b/.gitignore index a25e14a..9a79f19 100644 --- a/.gitignore +++ b/.gitignore @@ -11,4 +11,8 @@ scripts/*.ps1 scripts/*.sh **/dist **/build -*.log \ No newline at end of file +*.log +benchmark/ +modelTest/ +nc_workspace/ +debug_openai_history.txt \ No newline at end of file From d7586d430c05030e8a43a6b4bfe9bb2aa84bae2e Mon Sep 17 00:00:00 2001 From: szeyu Date: Fri, 4 Oct 2024 11:26:02 +0800 Subject: [PATCH 8/9] Renamed to npu_engine to intel_npu_engine to specify that it is intel processor --- src/embeddedllm/backend/{npu_engine.py => intel_npu_engine.py} | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename src/embeddedllm/backend/{npu_engine.py => intel_npu_engine.py} (100%) diff --git a/src/embeddedllm/backend/npu_engine.py b/src/embeddedllm/backend/intel_npu_engine.py similarity index 100% rename from src/embeddedllm/backend/npu_engine.py rename to src/embeddedllm/backend/intel_npu_engine.py From e0d320f25d4b862cf693116c9d551edcd764cabc Mon Sep 17 00:00:00 2001 From: szeyu Date: Fri, 4 Oct 2024 11:26:43 +0800 Subject: [PATCH 9/9] Add support for Intel NPU backend and handle unsupported processors --- src/embeddedllm/engine.py | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/src/embeddedllm/engine.py b/src/embeddedllm/engine.py index d3933e9..b341472 100644 --- a/src/embeddedllm/engine.py +++ b/src/embeddedllm/engine.py @@ -60,11 +60,17 @@ def __init__(self, model_path: str, vision: bool, device: str = "xpu", backend: elif self.backend == "npu": assert self.device == "npu", f"To run npu backend, device must be npu." 
             processor = get_processor_type()
-            assert processor == "Intel", f"Only support intel NPU"
-            from embeddedllm.backend.npu_engine import NPUEngine
+            if processor == "Intel":
+                from embeddedllm.backend.intel_npu_engine import NPUEngine
+
+                self.engine = NPUEngine(self.model_path, self.vision, self.device)
+                logger.info("Initializing Intel NPU backend: NPUEngine")
+
+            elif processor == "AMD":
+                raise SystemError("NPU support on AMD platforms is not available yet.")
 
-            self.engine = NPUEngine(self.model_path, self.vision, self.device)
-            logger.info(f"Initializing npu backend (NPU): NPUEngine")
+            else:
+                raise SystemError("Unsupported processor: the NPU backend currently requires an Intel NPU.")
 
         elif self.backend == "cpu":
             assert self.device == "cpu", f"To run `cpu` backend, `device` must be `cpu`."
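
Taken together, the new backend boils down to a handful of intel-npu-acceleration-library and transformers calls. The sketch below is a minimal, standalone smoke test along the lines of `intel_npu_engine.py` — it assumes the packages from `requirements-npu.txt` are installed and an Intel NPU is present, uses one of the verified models from `docs/model/npu_models.md`, and simply mirrors the loading arguments of the engine's own `NPUModelForCausalLM.from_pretrained` call.

```python
# Minimal smoke test for the Intel NPU path, mirroring intel_npu_engine.py.
# Assumes requirements-npu.txt is installed and an Intel NPU is available.
from threading import Thread

import intel_npu_acceleration_library as npu_lib
from transformers import AutoTokenizer, TextIteratorStreamer

model_id = "microsoft/Phi-3-mini-4k-instruct"  # one of the verified NPU-LLM models

tokenizer = AutoTokenizer.from_pretrained(model_id)
model = npu_lib.NPUModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype="auto",
    dtype=npu_lib.int4,   # 4-bit weights, as used by NPUEngine
    trust_remote_code=True,
    export=False,         # same arguments as the engine's loading call
)

# Stream tokens from a background generation thread, exactly as the
# engine's streaming branch does with TextIteratorStreamer.
streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
inputs = tokenizer("What is an NPU?", return_tensors="pt")

thread = Thread(
    target=model.generate,
    kwargs=dict(input_ids=inputs["input_ids"], max_new_tokens=64, streamer=streamer),
)
thread.start()
for text in streamer:
    print(text, end="", flush=True)
thread.join()
```

This is the same code path that `ellm_server --backend 'npu' --device 'npu'` exercises through `EmbeddedLLMEngine` and `NPUEngine`.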