flashtool: add serial device validation

Signed-off-by:  Eric Callahan <arksine.code@gmail.com>
This commit is contained in:
Eric Callahan 2024-08-20 06:12:47 -04:00
parent 42909f8a0d
commit 7de954b916

View File

@ -15,7 +15,15 @@ import errno
import argparse import argparse
import hashlib import hashlib
import pathlib import pathlib
import shutil
import shlex
from typing import Dict, List, Optional, Union from typing import Dict, List, Optional, Union
HAS_SERIAL = True
try:
from serial import Serial, SerialException
except ModuleNotFoundError:
HAS_SERIAL = False
SerialException = Exception
def output_line(msg: str) -> None: def output_line(msg: str) -> None:
sys.stdout.write(msg + "\n") sys.stdout.write(msg + "\n")
@ -85,7 +93,7 @@ class CanFlasher:
self.block_count = 0 self.block_count = 0
self.app_start_addr = 0 self.app_start_addr = 0
async def connect_btl(self): async def connect_btl(self) -> None:
output_line("Attempting to connect to bootloader") output_line("Attempting to connect to bootloader")
ret = await self.send_command('CONNECT') ret = await self.send_command('CONNECT')
pinfo = ret[:12] pinfo = ret[:12]
@ -213,6 +221,7 @@ class CanFlasher:
self.file_size = f.tell() self.file_size = f.tell()
f.seek(0) f.seek(0)
flash_address = self.app_start_addr flash_address = self.app_start_addr
recd_addr = 0
while True: while True:
buf = f.read(self.block_size) buf = f.read(self.block_size)
if not buf: if not buf:
@ -284,7 +293,7 @@ class CanFlasher:
class CanNode: class CanNode:
def __init__(self, node_id: int, cansocket: CanSocket) -> None: def __init__(self, node_id: int, cansocket: CanSocket | SerialSocket) -> None:
self.node_id = node_id self.node_id = node_id
self._reader = asyncio.StreamReader(CAN_READER_LIMIT) self._reader = asyncio.StreamReader(CAN_READER_LIMIT)
self._cansocket = cansocket self._cansocket = cansocket
@ -519,41 +528,90 @@ class CanSocket:
class SerialSocket: class SerialSocket:
def __init__(self, loop: asyncio.AbstractEventLoop): def __init__(self, loop: asyncio.AbstractEventLoop):
self._loop = loop self._loop = loop
self.serial = self.serial_error = None self.serial: Optional[Serial] = None
self.node = CanNode(0, self) self.node = CanNode(0, self)
def _handle_response(self) -> None: def _handle_response(self) -> None:
assert self.serial is not None
try: try:
data = self.serial.read(4096) data = self.serial.read(4096)
except self.serial_error: except SerialException:
logging.exception("Error on serial read") logging.exception("Error on serial read")
self.close() self.close()
else:
self.node.feed_data(data) self.node.feed_data(data)
def send(self, can_id: int, payload: bytes = b"") -> None: def send(self, can_id: int, payload: bytes = b"") -> None:
assert self.serial is not None
try: try:
self.serial.write(payload) self.serial.write(payload)
except self.serial_error: except SerialException:
logging.exception("Error on serial write") logging.exception("Error on serial write")
self.close() self.close()
async def _lookup_proc_name(self, process_id: str) -> str:
has_sysctl = shutil.which("systemctl") is not None
if has_sysctl:
cmd = shlex.split(f"systemctl status {process_id}")
proc = await asyncio.create_subprocess_exec(
*cmd,
stdout=asyncio.subprocess.PIPE,
stderr=asyncio.subprocess.PIPE
)
stdout, _ = await proc.communicate()
resp = stdout.strip().decode(errors="ignore")
if resp:
unit = resp.split(maxsplit=2)
if len(unit) == 3:
return f"Systemd Unit Name: {unit[1]}"
cmdline_file = pathlib.Path(f"/proc/{process_id}/cmdline")
if cmdline_file.exists():
res = cmdline_file.read_text().replace("\x00", " ").strip()
return f"Command Line: {res}"
exe_file = pathlib.Path(f"/proc/{process_id}/exe")
if exe_file.exists():
return f"Executable: {exe_file.resolve()})"
return "Name Unknown"
async def validate_device(self, dev_strpath: str) -> None:
dev_path = pathlib.Path(dev_strpath)
if not dev_path.exists():
raise FlashError(f"No Serial Device found at {dev_path}")
try:
dev_st = dev_path.stat()
except PermissionError as e:
raise FlashError(f"No permission to access device {dev_path}") from e
dev_id = (dev_st.st_dev, dev_st.st_ino)
for fd_dir in pathlib.Path("/proc").glob("*/fd"):
pid = fd_dir.parent.name
if not pid.isdigit():
continue
try:
for item in fd_dir.iterdir():
item_st = item.stat()
item_id = (item_st.st_dev, item_st.st_ino)
if item_id == dev_id:
proc_name = await self._lookup_proc_name(pid)
output_line(
f"Serial device {dev_path} in use by another program.\n"
f"Process ID: {pid}\n"
f"Process {proc_name}"
)
raise FlashError(f"Serial device {dev_path} in use")
except PermissionError:
continue
async def run(self, intf: str, baud: int, fw_path: pathlib.Path) -> None: async def run(self, intf: str, baud: int, fw_path: pathlib.Path) -> None:
if not fw_path.is_file(): if not fw_path.is_file():
raise FlashError("Invalid firmware path '%s'" % (fw_path)) raise FlashError("Invalid firmware path '%s'" % (fw_path))
await self.validate_device(intf)
try: try:
import serial serial_dev = Serial( # type: ignore
except ModuleNotFoundError: baudrate=baud, timeout=0, exclusive=True
raise FlashError( )
"The pyserial python package was not found. To install "
"run the following command in a terminal: \n\n"
" pip3 install pyserial\n\n")
self.serial_error = serial.SerialException
try:
serial_dev = serial.Serial(baudrate=baud, timeout=0,
exclusive=True)
serial_dev.port = intf serial_dev.port = intf
serial_dev.open() serial_dev.open()
except (OSError, IOError, self.serial_error) as e: except (OSError, IOError, SerialException) as e:
raise FlashError("Unable to open serial port: %s" % (e,)) raise FlashError("Unable to open serial port: %s" % (e,))
self.serial = serial_dev self.serial = serial_dev
self._loop.add_reader(self.serial.fileno(), self._handle_response) self._loop.add_reader(self.serial.fileno(), self._handle_response)
@ -634,6 +692,15 @@ def main():
uuid = int(args.uuid, 16) uuid = int(args.uuid, 16)
loop.run_until_complete(sock.run(intf, uuid, fpath, req_only)) loop.run_until_complete(sock.run(intf, uuid, fpath, req_only))
else: else:
if not HAS_SERIAL:
ser_inst_cmd = "pip3 install serial"
if shutil.which("apt") is not None:
ser_inst_cmd = "sudo apt install python3-serial"
raise FlashError(
"The pyserial python package was not found. To install "
"run the following command in a terminal: \n\n"
f" {ser_inst_cmd}\n\n"
)
if args.device is None: if args.device is None:
raise FlashError( raise FlashError(
"The 'device' option must be specified to flash a device" "The 'device' option must be specified to flash a device"