diff --git a/scripts/flashtool.py b/scripts/flashtool.py index 844b379..32e8cab 100755 --- a/scripts/flashtool.py +++ b/scripts/flashtool.py @@ -7,6 +7,8 @@ from __future__ import annotations import sys import os +import termios +import fcntl import zlib import json import asyncio @@ -81,9 +83,55 @@ CANBUS_CMD_CLEAR_NODE_ID = 0x12 CANBUS_RESP_NEED_NODEID = 0x20 CANBUS_NODEID_OFFSET = 128 +# USB IDs +KATAPULT_USB_ID = "1d50:6177" +KLIPPER_USB_ID = "1d50:614e" +SERIAL_BL_REQ = b"~ \x1c Request Serial Bootloader!! ~" + class FlashError(Exception): pass +def get_usb_info(usb_path: pathlib.Path) -> Dict[str, Any]: + usb_info: Dict[str, Any] = {} + id_path = usb_path.joinpath("idVendor") + prod_path = usb_path.joinpath("idProduct") + mfr_path = usb_path.joinpath("manufacturer") + if id_path.is_file() and prod_path.is_file(): + vid = id_path.read_text().strip().lower() + pid = prod_path.read_text().strip().lower() + usb_info["usb_id"] = f"{vid}:{pid}" + usb_info["manufacturer"] = "unknown" + if mfr_path.is_file(): + usb_info["manufacturer"] = mfr_path.read_text().strip().lower() + return usb_info + +def get_usb_path(device: pathlib.Path) -> Optional[pathlib.Path]: + device_path = device.resolve() + if not device_path.exists(): + return None + sys_dev_path = pathlib.Path("/sys/class/tty").joinpath(device_path.name) + if not sys_dev_path.exists(): + return None + sys_dev_path = sys_dev_path.resolve() + for usb_path in sys_dev_path.parents: + dnum_file = usb_path.joinpath("devnum") + bnum_file = usb_path.joinpath("busnum") + if dnum_file.is_file() and bnum_file.is_file(): + return usb_path + return None + +def get_stable_usb_symlink(device: pathlib.Path) -> pathlib.Path: + device_path = device.resolve() + ser_by_path_dir = pathlib.Path("/dev/serial/by-path") + try: + if ser_by_path_dir.exists(): + for item in ser_by_path_dir.iterdir(): + if item.samefile(device_path): + return item + except OSError: + pass + return device_path + class CanFlasher: def __init__( self, @@ -461,8 +509,6 @@ class CanSocket: self.output_busy = False def _jump_to_bootloader(self, uuid: int): - # TODO: Send Klipper Admin command to jump to bootloader. - # It will need to be implemented output_line("Sending bootloader jump command...") plist = [(uuid >> ((5 - i) * 8)) & 0xFF for i in range(6)] plist.insert(0, KLIPPER_REBOOT_CMD) @@ -650,19 +696,89 @@ class SerialSocket: ) raise FlashError(f"Serial device {dev_path} in use") - async def run(self, intf: str, baud: int, fw_path: pathlib.Path) -> None: - if not fw_path.is_file(): - raise FlashError("Invalid firmware path '%s'" % (fw_path)) - await self.validate_device(intf) + async def _request_usb_bootloader(self, device: pathlib.Path) -> pathlib.Path: + output_line(f"Requesting USB bootloader for {device}...") + usb_dev_path = get_usb_path(device) + if usb_dev_path is None: + output_line(f"Device path {device} is not a usb device") + return device + stable_path = get_stable_usb_symlink(device) + fd: Optional[int] = None + with contextlib.suppress(OSError): + fd = os.open(str(device), os.O_RDWR) + fcntl.ioctl( + fd, termios.TIOCMBIS, struct.pack('I', termios.TIOCM_DTR) + ) + t = termios.tcgetattr(fd) + t[4] = t[5] = termios.B1200 + termios.tcsetattr(fd, termios.TCSANOW, t) + fcntl.ioctl( + fd, termios.TIOCMBIC, struct.pack('I', termios.TIOCM_DTR) + ) + if fd is not None: + os.close(fd) + output("Waiting for USB Reconnect.") + for _ in range(8): + await asyncio.sleep(.5) + output(".") + usb_info = get_usb_info(usb_dev_path) + mfr = usb_info.get("manufacturer") + if mfr == "katapult": + output_line("Katapult detected") + await asyncio.sleep(1.0) + break + else: + output_line("timed out") + return stable_path + + async def _request_serial_bootloader(self, device: str, baud: int) -> None: + output_line(f"Requesting serial bootloader for device {device}...") + self.serial = self._open_device(device, baud) + self.send(0, SERIAL_BL_REQ) + await asyncio.sleep(1.) + if self.serial is not None: + self.close() + + def _open_device(self, device: str, baud: int) -> Serial: try: serial_dev = Serial( # type: ignore baudrate=baud, timeout=0, exclusive=True ) - serial_dev.port = intf + serial_dev.port = device serial_dev.open() except (OSError, IOError, SerialException) as e: raise FlashError("Unable to open serial port: %s" % (e,)) - self.serial = serial_dev + return serial_dev + + async def run( + self, intf: str, baud: int, fw_path: pathlib.Path, req_only: bool + ) -> None: + if not fw_path.is_file(): + raise FlashError("Invalid firmware path '%s'" % (fw_path)) + await self.validate_device(intf) + intf_path = pathlib.Path(intf) + usb_dev_path = get_usb_path(intf_path) + dev_info: Dict[str, Any] = {} + if usb_dev_path is not None: + dev_info = get_usb_info(usb_dev_path) + usb_id = dev_info.get("usb_id") + usb_mfr = dev_info.get("manufacturer") + if usb_mfr == "klipper" or usb_id == KLIPPER_USB_ID: + # Request usb bootloader, wait for katapult + output_line("Detected USB device running Klipper") + new_intf = await self._request_usb_bootloader(intf_path) + intf = str(new_intf) + if req_only: + return + elif usb_mfr == "katapult" or usb_id == KATAPULT_USB_ID: + output_line("Detected USB device running Katapult") + if req_only: + return + elif req_only: + # Request serial bootloader and exit + await self._request_serial_bootloader(intf, baud) + return + self.serial = self._open_device(intf, baud) self._loop.add_reader(self.serial.fileno(), self._handle_response) flasher = CanFlasher(self.node, fw_path) try: @@ -720,10 +836,10 @@ async def main(args: argparse.Namespace) -> int: ) output_line(f"Flashing Serial Device {args.device}, baud {args.baud}") sock = SerialSocket(loop) - await sock.run(args.device, args.baud, fpath) + await sock.run(args.device, args.baud, fpath, req_only) except Exception: logging.exception("Flash Error") - return -1 + return 1 finally: if sock is not None: sock.close() @@ -767,7 +883,7 @@ if __name__ == '__main__': ) parser.add_argument( "-r", "--request-bootloader", action="store_true", - help="Requests the bootloader and exits (CAN only)" + help="Requests the bootloader and exits" ) args = parser.parse_args()