flashtool: add serial device validation

Signed-off-by:  Eric Callahan <arksine.code@gmail.com>
This commit is contained in:
Eric Callahan 2024-08-20 06:12:47 -04:00
parent 42909f8a0d
commit 7de954b916

View File

@ -15,7 +15,15 @@ import errno
import argparse
import hashlib
import pathlib
import shutil
import shlex
from typing import Dict, List, Optional, Union
HAS_SERIAL = True
try:
from serial import Serial, SerialException
except ModuleNotFoundError:
HAS_SERIAL = False
SerialException = Exception
def output_line(msg: str) -> None:
sys.stdout.write(msg + "\n")
@ -85,7 +93,7 @@ class CanFlasher:
self.block_count = 0
self.app_start_addr = 0
async def connect_btl(self):
async def connect_btl(self) -> None:
output_line("Attempting to connect to bootloader")
ret = await self.send_command('CONNECT')
pinfo = ret[:12]
@ -213,6 +221,7 @@ class CanFlasher:
self.file_size = f.tell()
f.seek(0)
flash_address = self.app_start_addr
recd_addr = 0
while True:
buf = f.read(self.block_size)
if not buf:
@ -284,7 +293,7 @@ class CanFlasher:
class CanNode:
def __init__(self, node_id: int, cansocket: CanSocket) -> None:
def __init__(self, node_id: int, cansocket: CanSocket | SerialSocket) -> None:
self.node_id = node_id
self._reader = asyncio.StreamReader(CAN_READER_LIMIT)
self._cansocket = cansocket
@ -519,41 +528,90 @@ class CanSocket:
class SerialSocket:
def __init__(self, loop: asyncio.AbstractEventLoop):
self._loop = loop
self.serial = self.serial_error = None
self.serial: Optional[Serial] = None
self.node = CanNode(0, self)
def _handle_response(self) -> None:
assert self.serial is not None
try:
data = self.serial.read(4096)
except self.serial_error:
except SerialException:
logging.exception("Error on serial read")
self.close()
else:
self.node.feed_data(data)
def send(self, can_id: int, payload: bytes = b"") -> None:
assert self.serial is not None
try:
self.serial.write(payload)
except self.serial_error:
except SerialException:
logging.exception("Error on serial write")
self.close()
async def _lookup_proc_name(self, process_id: str) -> str:
has_sysctl = shutil.which("systemctl") is not None
if has_sysctl:
cmd = shlex.split(f"systemctl status {process_id}")
proc = await asyncio.create_subprocess_exec(
*cmd,
stdout=asyncio.subprocess.PIPE,
stderr=asyncio.subprocess.PIPE
)
stdout, _ = await proc.communicate()
resp = stdout.strip().decode(errors="ignore")
if resp:
unit = resp.split(maxsplit=2)
if len(unit) == 3:
return f"Systemd Unit Name: {unit[1]}"
cmdline_file = pathlib.Path(f"/proc/{process_id}/cmdline")
if cmdline_file.exists():
res = cmdline_file.read_text().replace("\x00", " ").strip()
return f"Command Line: {res}"
exe_file = pathlib.Path(f"/proc/{process_id}/exe")
if exe_file.exists():
return f"Executable: {exe_file.resolve()})"
return "Name Unknown"
async def validate_device(self, dev_strpath: str) -> None:
dev_path = pathlib.Path(dev_strpath)
if not dev_path.exists():
raise FlashError(f"No Serial Device found at {dev_path}")
try:
dev_st = dev_path.stat()
except PermissionError as e:
raise FlashError(f"No permission to access device {dev_path}") from e
dev_id = (dev_st.st_dev, dev_st.st_ino)
for fd_dir in pathlib.Path("/proc").glob("*/fd"):
pid = fd_dir.parent.name
if not pid.isdigit():
continue
try:
for item in fd_dir.iterdir():
item_st = item.stat()
item_id = (item_st.st_dev, item_st.st_ino)
if item_id == dev_id:
proc_name = await self._lookup_proc_name(pid)
output_line(
f"Serial device {dev_path} in use by another program.\n"
f"Process ID: {pid}\n"
f"Process {proc_name}"
)
raise FlashError(f"Serial device {dev_path} in use")
except PermissionError:
continue
async def run(self, intf: str, baud: int, fw_path: pathlib.Path) -> None:
if not fw_path.is_file():
raise FlashError("Invalid firmware path '%s'" % (fw_path))
await self.validate_device(intf)
try:
import serial
except ModuleNotFoundError:
raise FlashError(
"The pyserial python package was not found. To install "
"run the following command in a terminal: \n\n"
" pip3 install pyserial\n\n")
self.serial_error = serial.SerialException
try:
serial_dev = serial.Serial(baudrate=baud, timeout=0,
exclusive=True)
serial_dev = Serial( # type: ignore
baudrate=baud, timeout=0, exclusive=True
)
serial_dev.port = intf
serial_dev.open()
except (OSError, IOError, self.serial_error) as e:
except (OSError, IOError, SerialException) as e:
raise FlashError("Unable to open serial port: %s" % (e,))
self.serial = serial_dev
self._loop.add_reader(self.serial.fileno(), self._handle_response)
@ -634,6 +692,15 @@ def main():
uuid = int(args.uuid, 16)
loop.run_until_complete(sock.run(intf, uuid, fpath, req_only))
else:
if not HAS_SERIAL:
ser_inst_cmd = "pip3 install serial"
if shutil.which("apt") is not None:
ser_inst_cmd = "sudo apt install python3-serial"
raise FlashError(
"The pyserial python package was not found. To install "
"run the following command in a terminal: \n\n"
f" {ser_inst_cmd}\n\n"
)
if args.device is None:
raise FlashError(
"The 'device' option must be specified to flash a device"