Skip to content

Commit

Permalink
linting and type checking
Browse files Browse the repository at this point in the history
  • Loading branch information
snhobbs committed Jun 5, 2024
1 parent 947cb12 commit 71714cb
Show file tree
Hide file tree
Showing 3 changed files with 34 additions and 32 deletions.
7 changes: 6 additions & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ repos:
hooks:
# Linter
- id: ruff
args: [--fix, --exit-non-zero-on-fix]
args: [--fix, --exit-non-zero-on-fix, --unsafe-fixes]
# Formatter
- id: ruff-format

Expand All @@ -42,6 +42,11 @@ repos:
- id: djlint-reformat-django
- id: djlint-django

- repo: https://github.com/pre-commit/mirrors-mypy
rev: '' # Use the sha / tag you want to point at
hooks:
- id: mypy

# sets up .pre-commit-ci.yaml to ensure pre-commit dependencies stay up to date
ci:
autoupdate_schedule: weekly
Expand Down
43 changes: 22 additions & 21 deletions src/isp_programmer/ISPConnection.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import contextlib
import os
import time
import logging
Expand Down Expand Up @@ -137,7 +138,7 @@ def _write_serial(self, out: bytes) -> None:
self._delay_write_serial(out)
else:
self.iodevice.write(out)
logging.log(logging.DEBUG - 1, f"Write: [{out}]")
logging.log(logging.DEBUG - 1, f"Write: [{out.decode('utf-8')}]")

def _flush(self):
self.iodevice.flush()
Expand Down Expand Up @@ -353,7 +354,7 @@ def CheckSectorsBlank(self, start: int, end: int) -> bool:
_raise_return_code_error(response_code, "Blank Check Sectors")
return _return_code_success(response_code)

def ReadPartID(self) -> str:
def ReadPartID(self) -> int:
"""
Throws no exception
"""
Expand All @@ -366,11 +367,9 @@ def ReadPartID(self) -> str:
exception=timeout_decorator.TimeoutError,
raise_on_fail=False,
)()
try:
return int(resp) # handle none type passed
except ValueError:
pass
return resp
with contextlib.suppress(TypeError):
return int(resp)
return 0

def ReadBootCodeVersion(self):
"""
Expand Down Expand Up @@ -439,7 +438,7 @@ def ReadCRC(self, address: int, num_bytes: int) -> int:

def ReadFlashSig(
self, start: int, end: int, wait_states: int = 2, mode: int = 0
) -> str:
) -> list[str]:
assert start < end
response_code = self._write_command(f"Z {start} {end} {wait_states} {mode}")
_raise_return_code_error(response_code, "Read Flash Signature")
Expand Down Expand Up @@ -567,19 +566,20 @@ class ChipDescription:
"CRP3": 0x43218765,
}

def __init__(self, descriptor: dict):
def __init__(self, descriptor: dict[str, str]):
self.RAMRange = [0, 0]
self.RAMBufferSize = 0
self.FlashRange = [0, 0]

descriptor: dict
for name in dict(descriptor):
self.__setattr__(name, descriptor[name])
# for name in dict(descriptor):
# self.__setattr__(name, descriptor[name])
self.SectorCount: int = int(descriptor.pop("SectorCount"))
self.RAMStartWrite: int = int(descriptor.pop("RAMStartWrite"))
self.CrystalFrequency = 12000 # khz == 30MHz
self.kCheckSumLocation = 7 # 0x0000001c

@property
def MaxByteTransfer(self):
def MaxByteTransfer(self) -> int:
return self.RAMBufferSize

@property
Expand Down Expand Up @@ -789,7 +789,7 @@ def WriteBinaryToFlash(
logging.error(
f"Invalid sector count\t Start: {start_sector}\tCount: {sector_count}\tEnd: {chip.SectorCount}"
)
return
return 1
isp.Unlock()
for sector in reversed(range(start_sector, start_sector + sector_count)):
logging.info(f"\nWriting Sector {sector}")
Expand All @@ -810,6 +810,7 @@ def WriteBinaryToFlash(
logging.info("Programming Complete.")
return chip_flash_sig
"""
return 0


def WriteImage(
Expand Down Expand Up @@ -857,10 +858,10 @@ def ReadImage(isp: ISPConnection, chip: ChipDescription) -> bytes:
image = bytes()
blank_sector = FindFirstBlankSector(isp, chip)
logging.getLogger().info("First Blank Sector %d", blank_sector)
sectors = []
for sector in range(blank_sector):
logging.getLogger().info("Sector %d", sector)
sector = ReadSector(isp, chip, sector)
sectors: list[bytes] = []
for nsector in range(blank_sector):
logging.getLogger().info("Sector %d", nsector)
sector: bytes = ReadSector(isp, chip, nsector)
sectors.append(sector)

return image.join(sectors)
Expand All @@ -876,7 +877,7 @@ def MassErase(isp: ISPConnection, chip: ChipDescription):

def SetupChip(
baudrate: int,
device: object,
device: str,
crystal_frequency: int,
chip_file: str,
no_sync: bool = False,
Expand Down Expand Up @@ -907,7 +908,7 @@ def SetupChip(
kStartingBaudRate = BAUDRATES[0]

logging.info("baud rate %d", kStartingBaudRate)
iodevice = UartDevice(device, baudrate=kStartingBaudRate)
iodevice: UartDevice = UartDevice(device, baudrate=kStartingBaudRate)
isp = ISPConnection(iodevice)
isp.serial_sleep = serial_sleep
isp.return_code_sleep = sleep_time
Expand All @@ -924,7 +925,7 @@ def SetupChip(
isp.reset()
part_id = isp.ReadPartID()

descriptor = GetPartDescriptor(chip_file, part_id)
descriptor: dict[str, str] = GetPartDescriptor(chip_file, part_id)
logging.info(f"{part_id}, {descriptor}")
chip = ChipDescription(descriptor)
chip.CrystalFrequency = crystal_frequency
Expand Down
16 changes: 6 additions & 10 deletions src/isp_programmer/parts_definitions.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
]


def read_lpcparts_string(string: str):
def read_lpcparts_string(string: str) -> dict[str, list]:
lpc_tools_column_locations = {
"part id": 0,
"name": 1,
Expand All @@ -38,7 +38,7 @@ def read_lpcparts_string(string: str):
"RAMBufferSize": 9,
"UU Encode": 10,
}
df_dict = {}
df_dict: dict[str, list] = {}
for column in lpc_tools_column_locations:
df_dict[column] = []

Expand All @@ -48,11 +48,7 @@ def read_lpcparts_string(string: str):
continue
split_line = line.strip().split(",")
for column, index in lpc_tools_column_locations.items():
value = split_line[index].strip()
try:
value = int(value, 0)
except ValueError:
pass
value: int = int(split_line[index].strip(), 0)
df_dict[column].append(value)

for col in df_dict:
Expand All @@ -77,16 +73,16 @@ def ReadChipFile(fname: str) -> pandas.DataFrame:
return df


def GetPartDescriptorLine(fname: str, partid: int) -> list:
def GetPartDescriptorLine(fname: str, partid: int) -> dict[str, str]:
entries = ReadChipFile(fname)
for _, entry in entries.iterrows():
if partid == entry["part id"]:
print(partid, entry["part id"])
return entry
raise UserWarning(f"PartId {partid} not found in {fname}")


def GetPartDescriptor(fname: str, partid: int) -> dict:
def GetPartDescriptor(fname: str, partid: int) -> dict[str, str]:
# FIXME redundent function
descriptor = GetPartDescriptorLine(fname, partid)
if descriptor is None:
raise UserWarning("Warning chip %s not found in file %s" % (hex(partid), fname))
Expand Down

0 comments on commit 71714cb

Please sign in to comment.