From 71714cbe9b21862342a7d0655f9802eb125d0581 Mon Sep 17 00:00:00 2001 From: simon Date: Wed, 5 Jun 2024 00:03:15 -0400 Subject: [PATCH] linting and type checking --- .pre-commit-config.yaml | 7 +++- src/isp_programmer/ISPConnection.py | 43 +++++++++++++------------ src/isp_programmer/parts_definitions.py | 16 ++++----- 3 files changed, 34 insertions(+), 32 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index ce040c1..f0fb9fa 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -32,7 +32,7 @@ repos: hooks: # Linter - id: ruff - args: [--fix, --exit-non-zero-on-fix] + args: [--fix, --exit-non-zero-on-fix, --unsafe-fixes] # Formatter - id: ruff-format @@ -42,6 +42,11 @@ repos: - id: djlint-reformat-django - id: djlint-django + - repo: https://github.com/pre-commit/mirrors-mypy + rev: '' # Use the sha / tag you want to point at + hooks: + - id: mypy + # sets up .pre-commit-ci.yaml to ensure pre-commit dependencies stay up to date ci: autoupdate_schedule: weekly diff --git a/src/isp_programmer/ISPConnection.py b/src/isp_programmer/ISPConnection.py index d8a8066..4bbc295 100644 --- a/src/isp_programmer/ISPConnection.py +++ b/src/isp_programmer/ISPConnection.py @@ -1,3 +1,4 @@ +import contextlib import os import time import logging @@ -137,7 +138,7 @@ def _write_serial(self, out: bytes) -> None: self._delay_write_serial(out) else: self.iodevice.write(out) - logging.log(logging.DEBUG - 1, f"Write: [{out}]") + logging.log(logging.DEBUG - 1, f"Write: [{out.decode('utf-8')}]") def _flush(self): self.iodevice.flush() @@ -353,7 +354,7 @@ def CheckSectorsBlank(self, start: int, end: int) -> bool: _raise_return_code_error(response_code, "Blank Check Sectors") return _return_code_success(response_code) - def ReadPartID(self) -> str: + def ReadPartID(self) -> int: """ Throws no exception """ @@ -366,11 +367,9 @@ def ReadPartID(self) -> str: exception=timeout_decorator.TimeoutError, raise_on_fail=False, )() - try: - return int(resp) # handle none type passed - except ValueError: - pass - return resp + with contextlib.suppress(TypeError): + return int(resp) + return 0 def ReadBootCodeVersion(self): """ @@ -439,7 +438,7 @@ def ReadCRC(self, address: int, num_bytes: int) -> int: def ReadFlashSig( self, start: int, end: int, wait_states: int = 2, mode: int = 0 - ) -> str: + ) -> list[str]: assert start < end response_code = self._write_command(f"Z {start} {end} {wait_states} {mode}") _raise_return_code_error(response_code, "Read Flash Signature") @@ -567,19 +566,20 @@ class ChipDescription: "CRP3": 0x43218765, } - def __init__(self, descriptor: dict): + def __init__(self, descriptor: dict[str, str]): self.RAMRange = [0, 0] self.RAMBufferSize = 0 self.FlashRange = [0, 0] - descriptor: dict - for name in dict(descriptor): - self.__setattr__(name, descriptor[name]) + # for name in dict(descriptor): + # self.__setattr__(name, descriptor[name]) + self.SectorCount: int = int(descriptor.pop("SectorCount")) + self.RAMStartWrite: int = int(descriptor.pop("RAMStartWrite")) self.CrystalFrequency = 12000 # khz == 30MHz self.kCheckSumLocation = 7 # 0x0000001c @property - def MaxByteTransfer(self): + def MaxByteTransfer(self) -> int: return self.RAMBufferSize @property @@ -789,7 +789,7 @@ def WriteBinaryToFlash( logging.error( f"Invalid sector count\t Start: {start_sector}\tCount: {sector_count}\tEnd: {chip.SectorCount}" ) - return + return 1 isp.Unlock() for sector in reversed(range(start_sector, start_sector + sector_count)): logging.info(f"\nWriting Sector {sector}") @@ -810,6 +810,7 @@ def WriteBinaryToFlash( logging.info("Programming Complete.") return chip_flash_sig """ + return 0 def WriteImage( @@ -857,10 +858,10 @@ def ReadImage(isp: ISPConnection, chip: ChipDescription) -> bytes: image = bytes() blank_sector = FindFirstBlankSector(isp, chip) logging.getLogger().info("First Blank Sector %d", blank_sector) - sectors = [] - for sector in range(blank_sector): - logging.getLogger().info("Sector %d", sector) - sector = ReadSector(isp, chip, sector) + sectors: list[bytes] = [] + for nsector in range(blank_sector): + logging.getLogger().info("Sector %d", nsector) + sector: bytes = ReadSector(isp, chip, nsector) sectors.append(sector) return image.join(sectors) @@ -876,7 +877,7 @@ def MassErase(isp: ISPConnection, chip: ChipDescription): def SetupChip( baudrate: int, - device: object, + device: str, crystal_frequency: int, chip_file: str, no_sync: bool = False, @@ -907,7 +908,7 @@ def SetupChip( kStartingBaudRate = BAUDRATES[0] logging.info("baud rate %d", kStartingBaudRate) - iodevice = UartDevice(device, baudrate=kStartingBaudRate) + iodevice: UartDevice = UartDevice(device, baudrate=kStartingBaudRate) isp = ISPConnection(iodevice) isp.serial_sleep = serial_sleep isp.return_code_sleep = sleep_time @@ -924,7 +925,7 @@ def SetupChip( isp.reset() part_id = isp.ReadPartID() - descriptor = GetPartDescriptor(chip_file, part_id) + descriptor: dict[str, str] = GetPartDescriptor(chip_file, part_id) logging.info(f"{part_id}, {descriptor}") chip = ChipDescription(descriptor) chip.CrystalFrequency = crystal_frequency diff --git a/src/isp_programmer/parts_definitions.py b/src/isp_programmer/parts_definitions.py index 8759cb4..0265e5e 100644 --- a/src/isp_programmer/parts_definitions.py +++ b/src/isp_programmer/parts_definitions.py @@ -24,7 +24,7 @@ ] -def read_lpcparts_string(string: str): +def read_lpcparts_string(string: str) -> dict[str, list]: lpc_tools_column_locations = { "part id": 0, "name": 1, @@ -38,7 +38,7 @@ def read_lpcparts_string(string: str): "RAMBufferSize": 9, "UU Encode": 10, } - df_dict = {} + df_dict: dict[str, list] = {} for column in lpc_tools_column_locations: df_dict[column] = [] @@ -48,11 +48,7 @@ def read_lpcparts_string(string: str): continue split_line = line.strip().split(",") for column, index in lpc_tools_column_locations.items(): - value = split_line[index].strip() - try: - value = int(value, 0) - except ValueError: - pass + value: int = int(split_line[index].strip(), 0) df_dict[column].append(value) for col in df_dict: @@ -77,16 +73,16 @@ def ReadChipFile(fname: str) -> pandas.DataFrame: return df -def GetPartDescriptorLine(fname: str, partid: int) -> list: +def GetPartDescriptorLine(fname: str, partid: int) -> dict[str, str]: entries = ReadChipFile(fname) for _, entry in entries.iterrows(): if partid == entry["part id"]: - print(partid, entry["part id"]) return entry raise UserWarning(f"PartId {partid} not found in {fname}") -def GetPartDescriptor(fname: str, partid: int) -> dict: +def GetPartDescriptor(fname: str, partid: int) -> dict[str, str]: + # FIXME redundent function descriptor = GetPartDescriptorLine(fname, partid) if descriptor is None: raise UserWarning("Warning chip %s not found in file %s" % (hex(partid), fname))