From 7de954b916329874fe8cd133bef85108eeb6b948 Mon Sep 17 00:00:00 2001 From: Eric Callahan Date: Tue, 20 Aug 2024 06:12:47 -0400 Subject: [PATCH] flashtool: add serial device validation Signed-off-by: Eric Callahan --- scripts/flashtool.py | 101 +++++++++++++++++++++++++++++++++++-------- 1 file changed, 84 insertions(+), 17 deletions(-) diff --git a/scripts/flashtool.py b/scripts/flashtool.py index 7b2f76d..f8718f4 100755 --- a/scripts/flashtool.py +++ b/scripts/flashtool.py @@ -15,7 +15,15 @@ import errno import argparse import hashlib import pathlib +import shutil +import shlex from typing import Dict, List, Optional, Union +HAS_SERIAL = True +try: + from serial import Serial, SerialException +except ModuleNotFoundError: + HAS_SERIAL = False + SerialException = Exception def output_line(msg: str) -> None: sys.stdout.write(msg + "\n") @@ -85,7 +93,7 @@ class CanFlasher: self.block_count = 0 self.app_start_addr = 0 - async def connect_btl(self): + async def connect_btl(self) -> None: output_line("Attempting to connect to bootloader") ret = await self.send_command('CONNECT') pinfo = ret[:12] @@ -213,6 +221,7 @@ class CanFlasher: self.file_size = f.tell() f.seek(0) flash_address = self.app_start_addr + recd_addr = 0 while True: buf = f.read(self.block_size) if not buf: @@ -284,7 +293,7 @@ class CanFlasher: class CanNode: - def __init__(self, node_id: int, cansocket: CanSocket) -> None: + def __init__(self, node_id: int, cansocket: CanSocket | SerialSocket) -> None: self.node_id = node_id self._reader = asyncio.StreamReader(CAN_READER_LIMIT) self._cansocket = cansocket @@ -519,41 +528,90 @@ class CanSocket: class SerialSocket: def __init__(self, loop: asyncio.AbstractEventLoop): self._loop = loop - self.serial = self.serial_error = None + self.serial: Optional[Serial] = None self.node = CanNode(0, self) def _handle_response(self) -> None: + assert self.serial is not None try: data = self.serial.read(4096) - except self.serial_error: + except SerialException: logging.exception("Error on serial read") self.close() - self.node.feed_data(data) + else: + self.node.feed_data(data) def send(self, can_id: int, payload: bytes = b"") -> None: + assert self.serial is not None try: self.serial.write(payload) - except self.serial_error: + except SerialException: logging.exception("Error on serial write") self.close() + async def _lookup_proc_name(self, process_id: str) -> str: + has_sysctl = shutil.which("systemctl") is not None + if has_sysctl: + cmd = shlex.split(f"systemctl status {process_id}") + proc = await asyncio.create_subprocess_exec( + *cmd, + stdout=asyncio.subprocess.PIPE, + stderr=asyncio.subprocess.PIPE + ) + stdout, _ = await proc.communicate() + resp = stdout.strip().decode(errors="ignore") + if resp: + unit = resp.split(maxsplit=2) + if len(unit) == 3: + return f"Systemd Unit Name: {unit[1]}" + cmdline_file = pathlib.Path(f"/proc/{process_id}/cmdline") + if cmdline_file.exists(): + res = cmdline_file.read_text().replace("\x00", " ").strip() + return f"Command Line: {res}" + exe_file = pathlib.Path(f"/proc/{process_id}/exe") + if exe_file.exists(): + return f"Executable: {exe_file.resolve()})" + return "Name Unknown" + + async def validate_device(self, dev_strpath: str) -> None: + dev_path = pathlib.Path(dev_strpath) + if not dev_path.exists(): + raise FlashError(f"No Serial Device found at {dev_path}") + try: + dev_st = dev_path.stat() + except PermissionError as e: + raise FlashError(f"No permission to access device {dev_path}") from e + dev_id = (dev_st.st_dev, dev_st.st_ino) + for fd_dir in pathlib.Path("/proc").glob("*/fd"): + pid = fd_dir.parent.name + if not pid.isdigit(): + continue + try: + for item in fd_dir.iterdir(): + item_st = item.stat() + item_id = (item_st.st_dev, item_st.st_ino) + if item_id == dev_id: + proc_name = await self._lookup_proc_name(pid) + output_line( + f"Serial device {dev_path} in use by another program.\n" + f"Process ID: {pid}\n" + f"Process {proc_name}" + ) + raise FlashError(f"Serial device {dev_path} in use") + except PermissionError: + continue + async def run(self, intf: str, baud: int, fw_path: pathlib.Path) -> None: if not fw_path.is_file(): raise FlashError("Invalid firmware path '%s'" % (fw_path)) + await self.validate_device(intf) try: - import serial - except ModuleNotFoundError: - raise FlashError( - "The pyserial python package was not found. To install " - "run the following command in a terminal: \n\n" - " pip3 install pyserial\n\n") - self.serial_error = serial.SerialException - try: - serial_dev = serial.Serial(baudrate=baud, timeout=0, - exclusive=True) + serial_dev = Serial( # type: ignore + baudrate=baud, timeout=0, exclusive=True + ) serial_dev.port = intf serial_dev.open() - except (OSError, IOError, self.serial_error) as e: + except (OSError, IOError, SerialException) as e: raise FlashError("Unable to open serial port: %s" % (e,)) self.serial = serial_dev self._loop.add_reader(self.serial.fileno(), self._handle_response) @@ -634,6 +692,15 @@ def main(): uuid = int(args.uuid, 16) loop.run_until_complete(sock.run(intf, uuid, fpath, req_only)) else: + if not HAS_SERIAL: + ser_inst_cmd = "pip3 install serial" + if shutil.which("apt") is not None: + ser_inst_cmd = "sudo apt install python3-serial" + raise FlashError( + "The pyserial python package was not found. To install " + "run the following command in a terminal: \n\n" + f" {ser_inst_cmd}\n\n" + ) if args.device is None: raise FlashError( "The 'device' option must be specified to flash a device"