diff --git a/scripts/flashtool.py b/scripts/flashtool.py index 755fb7b..2f2a5b0 100755 --- a/scripts/flashtool.py +++ b/scripts/flashtool.py @@ -61,7 +61,7 @@ BOOTLOADER_CMDS = { 'SEND_EOF': 0x13, 'REQUEST_BLOCK': 0x14, 'COMPLETE': 0x15, - 'GET_CANBUS_ID': 0x16, + 'GET_CANBUS_ID': 0x16 } ACK_SUCCESS = 0xa0 @@ -158,7 +158,7 @@ class CanFlasher: Extract klipper.dict from binary """ fw_name = self.firmware_path.name.lower() - if fw_name != "klipper.bin": + if fw_name != "klipper.bin" or not self.firmware_path.is_file(): return bin_data = self.firmware_path.read_bytes() klipper_dict: Dict[str, Any] = {} @@ -451,9 +451,54 @@ class CanNode: def close(self) -> None: self._reader.feed_eof() -class CanSocket: - def __init__(self, loop: asyncio.AbstractEventLoop): - self._loop = loop +class BaseSocket: + def __init__(self, args: argparse.Namespace) -> None: + self._loop = asyncio.get_running_loop() + self._args = args + self._fw_path = pathlib.Path(args.firmware).expanduser().resolve() + + @property + def is_flash_req(self) -> bool: + return not ( + self.is_bootloader_req or self.is_status_req or self.is_query + ) + + @property + def is_bootloader_req(self) -> bool: + return self._args.request_bootloader + + @property + def is_status_req(self) -> bool: + return self._args.status + + @property + def is_query(self) -> bool: + return self._args.query + + def _check_firmware(self) -> None: + if self.is_flash_req and not self._fw_path.is_file(): + raise FlashError("Invalid firmware path '%s'" % (self._fw_path)) + + async def run(self) -> None: + raise NotImplementedError() + + def close(self) -> None: + raise NotImplementedError() + +class CanSocket(BaseSocket): + def __init__(self, args: argparse.Namespace) -> None: + super().__init__(args) + self._uuid = 0 + self._can_interface = args.interface + if not self.is_query: + if args.uuid is None: + raise FlashError( + "The 'uuid' option must be specified to flash a CAN device" + ) + else: + intf = self._can_interface + self._uuid = int(args.uuid, 16) + output_line(f"Connecting to CAN UUID {args.uuid} on interface {intf}") self.cansock = socket.socket(socket.PF_CAN, socket.SOCK_RAW, socket.CAN_RAW) self.admin_node = CanNode(CANBUS_ID_ADMIN, self) @@ -585,52 +630,41 @@ class CanSocket: self.nodes[decoded_id + 1] = node return node - async def run( - self, intf: str, uuid: int, fw_path: pathlib.Path, req_only: bool - ) -> None: - if not req_only and not fw_path.is_file(): - raise FlashError("Invalid firmware path '%s'" % (fw_path)) + async def run(self) -> None: + self._check_firmware() try: - self.cansock.bind((intf,)) + self.cansock.bind((self._can_interface,)) except Exception: - raise FlashError("Unable to bind socket to can0") + raise FlashError(f"Unable to bind socket to {self._can_interface}") self.closed = False self.cansock.setblocking(False) self._loop.add_reader( self.cansock.fileno(), self._handle_can_response) - self._jump_to_bootloader(uuid) - await asyncio.sleep(.5) - if req_only: - output_line("Bootloader request command sent") - return + if self.is_flash_req or self.is_bootloader_req: + self._jump_to_bootloader(self._uuid) + await asyncio.sleep(1.0) + if self.is_bootloader_req: + return self._reset_nodes() - await asyncio.sleep(1.0) - node = self._set_node_id(uuid) - flasher = CanFlasher(node, fw_path) + await asyncio.sleep(.5) + if self.is_query: + await self._query_uuids() + return + node = self._set_node_id(self._uuid) + flasher = CanFlasher(node, self._fw_path) await asyncio.sleep(.5) try: await flasher.connect_btl() - await flasher.verify_canbus_uuid(uuid) - await flasher.send_file() - await flasher.verify_file() + await flasher.verify_canbus_uuid(self._uuid) + if not self.is_status_req: + await flasher.send_file() + await flasher.verify_file() finally: # always attempt to send the complete command. If # there is an error it will exit the bootloader # unless comms were broken - await flasher.finish() - - async def run_query(self, intf: str): - try: - self.cansock.bind((intf,)) - except Exception: - raise FlashError("Unable to bind socket to can0") - self.closed = False - self.cansock.setblocking(False) - self._loop.add_reader( - self.cansock.fileno(), self._handle_can_response) - self._reset_nodes() - await asyncio.sleep(.5) - await self._query_uuids() + if self.is_flash_req: + await flasher.finish() def close(self): if self.closed: @@ -641,12 +675,32 @@ class CanSocket: self._loop.remove_reader(self.cansock.fileno()) self.cansock.close() -class SerialSocket: - def __init__(self, loop: asyncio.AbstractEventLoop): - self._loop = loop +class SerialSocket(BaseSocket): + def __init__(self, args: argparse.Namespace) -> None: + super().__init__(args) + self._device = args.device + self._baud = args.baud + if not HAS_SERIAL: + ser_inst_cmd = "pip3 install serial" + if shutil.which("apt") is not None: + ser_inst_cmd = "sudo apt install python3-serial" + raise FlashError( + "The pyserial python package was not found. To install " + "run the following command in a terminal: \n\n" + f" {ser_inst_cmd}\n\n" + ) + if self._device is None: + raise FlashError( + "The 'device' option must be specified to flash a device" + ) + output_line(f"Connecting to Serial Device {self._device}, baud {self._baud}") self.serial: Optional[Serial] = None self.node = CanNode(0, self) + @property + def is_query(self) -> bool: + return False + def _handle_response(self) -> None: assert self.serial is not None try: @@ -778,14 +832,12 @@ class SerialSocket: return variant not in ("f2", "f4", "h7") return False - async def run( - self, intf: str, baud: int, fw_path: pathlib.Path, req_only: bool - ) -> None: - if not fw_path.is_file(): - raise FlashError("Invalid firmware path '%s'" % (fw_path)) - await self.validate_device(intf) - intf_path = pathlib.Path(intf) - usb_dev_path = get_usb_path(intf_path) + async def run(self) -> None: + self._check_firmware() + device = self._device + await self.validate_device(device) + dev_path = pathlib.Path(device) + usb_dev_path = get_usb_path(dev_path) dev_info: Dict[str, Any] = {} if usb_dev_path is not None: dev_info = get_usb_info(usb_dev_path) @@ -795,23 +847,23 @@ class SerialSocket: if usb_mfr == "klipper" or usb_id == KLIPPER_USB_ID: # Request usb bootloader, wait for katapult output_line("Detected USB device running Klipper") - new_intf = await self._request_usb_bootloader(intf_path) - intf = str(new_intf) - if req_only: + new_dpath = await self._request_usb_bootloader(dev_path) + device = str(new_dpath) + if self.is_bootloader_req: return elif usb_mfr == "katapult" or usb_id == KATAPULT_USB_ID: output_line("Detected USB device running Katapult") - if req_only: + if self.is_bootloader_req: return - elif req_only: + elif self.is_bootloader_req: # Request serial bootloader and exit - await self._request_serial_bootloader(intf, baud) + await self._request_serial_bootloader(device, self._baud) return else: usb_prod = "" - self.serial = self._open_device(intf, baud) + self.serial = self._open_device(device, self._baud) self._loop.add_reader(self.serial.fileno(), self._handle_response) - flasher = CanFlasher(self.node, fw_path) + flasher = CanFlasher(self.node, self._fw_path) try: if self._has_double_buffering(usb_prod): # Prime the USB Connection with a dummy command. This is @@ -819,13 +871,15 @@ class SerialSocket: # to respond immediately to the connect command. flasher.prime() await flasher.connect_btl() - await flasher.send_file() - await flasher.verify_file() + if not self.is_status_req: + await flasher.send_file() + await flasher.verify_file() finally: # always attempt to send the complete command. If # there is an error it will exit the bootloader # unless comms were broken - await flasher.finish() + if self.is_flash_req: + await flasher.finish() def close(self): if self.serial is None: @@ -837,52 +891,30 @@ class SerialSocket: async def main(args: argparse.Namespace) -> int: if not args.verbose: logging.getLogger().setLevel(logging.ERROR) - intf = args.interface - fpath = pathlib.Path(args.firmware).expanduser().resolve() - loop = asyncio.get_running_loop() iscan = args.device is None - req_only = args.request_bootloader sock: CanSocket | SerialSocket | None = None try: if iscan: - sock = CanSocket(loop) - if args.query: - await sock.run_query(intf) - else: - if args.uuid is None: - raise FlashError( - "The 'uuid' option must be specified to flash a device" - ) - output_line(f"Flashing CAN UUID {args.uuid} on interface {intf}") - uuid = int(args.uuid, 16) - await sock.run(intf, uuid, fpath, req_only) + sock = CanSocket(args) else: - if not HAS_SERIAL: - ser_inst_cmd = "pip3 install serial" - if shutil.which("apt") is not None: - ser_inst_cmd = "sudo apt install python3-serial" - raise FlashError( - "The pyserial python package was not found. To install " - "run the following command in a terminal: \n\n" - f" {ser_inst_cmd}\n\n" - ) - if args.device is None: - raise FlashError( - "The 'device' option must be specified to flash a device" - ) - output_line(f"Flashing Serial Device {args.device}, baud {args.baud}") - sock = SerialSocket(loop) - await sock.run(args.device, args.baud, fpath, req_only) + sock = SerialSocket(args) + await sock.run() except Exception: - logging.exception("Flash Error") + logging.exception("Flash Tool Error") return 1 finally: if sock is not None: sock.close() - if args.query: - output_line("Query Complete") + if sock is None: + return 1 + if sock.is_query: + output_line("CANBus UUID Query Complete") + elif sock.is_bootloader_req: + output_line("Bootloader Request Complete") + elif sock.is_status_req: + output_line("Status Request Complete") else: - output_line("Flash Success") + output_line("Programming Complete") return 0 @@ -911,7 +943,7 @@ if __name__ == '__main__': ) parser.add_argument( "-q", "--query", action="store_true", - help="Query Bootloader Device IDs" + help="Query available CAN UUIDs (CANBus Ony)" ) parser.add_argument( "-v", "--verbose", action="store_true", @@ -921,6 +953,9 @@ if __name__ == '__main__': "-r", "--request-bootloader", action="store_true", help="Requests the bootloader and exits" ) - + parser.add_argument( + "-s", "--status", action="store_true", + help="Connect to bootloader and print status" + ) args = parser.parse_args() exit(asyncio.run(main(args)))