#!/usr/bin/env python3 # Script to upload software via Katapult # # Copyright (C) 2022 Eric Callahan # # This file may be distributed under the terms of the GNU GPLv3 license. from __future__ import annotations import sys import os import asyncio import socket import struct import logging import errno import argparse import hashlib import pathlib import shutil import shlex from typing import Dict, List, Optional, Union HAS_SERIAL = True try: from serial import Serial, SerialException except ModuleNotFoundError: HAS_SERIAL = False SerialException = Exception def output_line(msg: str) -> None: sys.stdout.write(msg + "\n") sys.stdout.flush() def output(msg: str) -> None: sys.stdout.write(msg) sys.stdout.flush() # Standard crc16 ccitt, take from msgproto.py in Klipper def crc16_ccitt(buf: Union[bytes, bytearray]) -> int: crc = 0xffff for data in buf: data ^= crc & 0xff data ^= (data & 0x0f) << 4 crc = ((data << 8) | (crc >> 8)) ^ (data >> 4) ^ (data << 3) return crc & 0xFFFF logging.basicConfig(level=logging.INFO) CAN_FMT = " None: self.node = node self.fw_name = fw_file self.fw_sha = hashlib.sha1() self.file_size = 0 self.block_size = 64 self.block_count = 0 self.app_start_addr = 0 async def connect_btl(self) -> None: output_line("Attempting to connect to bootloader") ret = await self.send_command('CONNECT') pinfo = ret[:12] mcu_info = ret[12:] ver_bytes: bytes ver_bytes, start_addr, self.block_size = struct.unpack("<4sII", pinfo) self.app_start_addr = start_addr self.software_version = "?" self.proto_version = tuple([v for v in reversed(ver_bytes[:3])]) proto_version_str = ".".join([str(v) for v in self.proto_version]) if self.block_size not in [64, 128, 256, 512]: raise FlashError("Invalid Block Size: %d" % (self.block_size,)) mcu_info.rstrip(b"\x00") if self.proto_version >= (1, 1, 0): mcu_bytes, sv_bytes = mcu_info.split(b"\x00", maxsplit=1) mcu_type = mcu_bytes.decode() self.software_version = sv_bytes.decode() else: mcu_type = mcu_info.decode() output_line( f"Katapult Connected\n" f"Software Version: {self.software_version}\n" f"Protocol Version: {proto_version_str}\n" f"Block Size: {self.block_size} bytes\n" f"Application Start: 0x{self.app_start_addr:4X}\n" f"MCU type: {mcu_type}" ) async def verify_canbus_uuid(self, uuid): output_line("Verifying canbus connection") ret = await self.send_command('GET_CANBUS_ID') mcu_uuid = sum([v << ((5 - i) * 8) for i, v in enumerate(ret[:6])]) if mcu_uuid != uuid: raise FlashError("UUID mismatch (%s vs %s)" % (uuid, mcu_uuid)) async def send_command( self, cmdname: str, payload: bytes = b"", tries: int = 5 ) -> bytearray: word_cnt = (len(payload) // 4) & 0xFF cmd = BOOTLOADER_CMDS[cmdname] out_cmd = bytearray(CMD_HEADER) out_cmd.append(cmd) out_cmd.append(word_cnt) if payload: out_cmd.extend(payload) crc = crc16_ccitt(out_cmd[2:]) out_cmd.extend(struct.pack(" 7: if data[:2] != CMD_HEADER: data = data[1:] continue recd_len = data[3] * 4 read_done = len(data) == recd_len + 8 break except asyncio.CancelledError: raise except asyncio.TimeoutError: logging.info( f"Response for command {cmdname} timed out, " f"{tries - 1} tries remaining" ) except Exception as e: if type(e) is type(last_err) and e.args == last_err.args: last_err = e logging.exception("Device Read Error") else: trailer = data[-2:] recd_crc, = struct.unpack("= last_percent + 2: last_percent += 2. output("#") resp = await self.send_command('SEND_EOF') page_count, = struct.unpack("= last_percent + 2: last_percent += 2 output("#") ver_hex = ver_sha.hexdigest().upper() fw_hex = self.fw_sha.hexdigest().upper() if ver_hex != fw_hex: raise FlashError("Checksum mismatch: Expected %s, Received %s" % (fw_hex, ver_hex)) output_line("]\n\nVerification Complete: SHA = %s" % (ver_hex)) async def finish(self): await self.send_command("COMPLETE") class CanNode: def __init__(self, node_id: int, cansocket: CanSocket | SerialSocket) -> None: self.node_id = node_id self._reader = asyncio.StreamReader(CAN_READER_LIMIT) self._cansocket = cansocket async def read( self, n: int = -1, timeout: Optional[float] = 2 ) -> bytes: return await asyncio.wait_for(self._reader.read(n), timeout) async def readexactly( self, n: int, timeout: Optional[float] = 2 ) -> bytes: return await asyncio.wait_for(self._reader.readexactly(n), timeout) async def readuntil( self, sep: bytes = b"\x03", timeout: Optional[float] = 2 ) -> bytes: return await asyncio.wait_for(self._reader.readuntil(sep), timeout) def write(self, payload: Union[bytes, bytearray]) -> None: if isinstance(payload, bytearray): payload = bytes(payload) self._cansocket.send(self.node_id, payload) async def write_with_response( self, payload: Union[bytearray, bytes], resp_length: int, timeout: Optional[float] = 2. ) -> bytes: self.write(payload) return await self.readexactly(resp_length, timeout) def feed_data(self, data: bytes) -> None: self._reader.feed_data(data) def close(self) -> None: self._reader.feed_eof() class CanSocket: def __init__(self, loop: asyncio.AbstractEventLoop): self._loop = loop self.cansock = socket.socket(socket.PF_CAN, socket.SOCK_RAW, socket.CAN_RAW) self.admin_node = CanNode(CANBUS_ID_ADMIN, self) self.nodes: Dict[int, CanNode] = { CANBUS_ID_ADMIN_RESP: self.admin_node } self.input_buffer = b"" self.output_packets: List[bytes] = [] self.input_busy = False self.output_busy = False self.closed = True def _handle_can_response(self) -> None: try: data = self.cansock.recv(4096) except socket.error as e: # If bad file descriptor allow connection to be # closed by the data check if e.errno == errno.EBADF: logging.exception("Can Socket Read Error, closing") data = b'' else: return if not data: # socket closed self.close() return self.input_buffer += data if self.input_busy: return self.input_busy = True while len(self.input_buffer) >= 16: packet = self.input_buffer[:16] self._process_packet(packet) self.input_buffer = self.input_buffer[16:] self.input_busy = False def _process_packet(self, packet: bytes) -> None: can_id, length, data = struct.unpack(CAN_FMT, packet) can_id &= socket.CAN_EFF_MASK payload = data[:length] node = self.nodes.get(can_id) if node is not None: node.feed_data(payload) def send(self, can_id: int, payload: bytes = b"") -> None: if can_id > 0x7FF: can_id |= socket.CAN_EFF_FLAG if not payload: packet = struct.pack(CAN_FMT, can_id, 0, b"") self.output_packets.append(packet) else: while payload: length = min(len(payload), 8) pkt_data = payload[:length] payload = payload[length:] packet = struct.pack( CAN_FMT, can_id, length, pkt_data) self.output_packets.append(packet) if self.output_busy: return self.output_busy = True asyncio.create_task(self._do_can_send()) async def _do_can_send(self): while self.output_packets: packet = self.output_packets.pop(0) try: await self._loop.sock_sendall(self.cansock, packet) except socket.error: logging.info("Socket Write Error, closing") self.close() break self.output_busy = False def _jump_to_bootloader(self, uuid: int): # TODO: Send Klipper Admin command to jump to bootloader. # It will need to be implemented output_line("Sending bootloader jump command...") plist = [(uuid >> ((5 - i) * 8)) & 0xFF for i in range(6)] plist.insert(0, KLIPPER_REBOOT_CMD) self.send(KLIPPER_ADMIN_ID, bytes(plist)) async def _query_uuids(self) -> List[int]: output_line("Checking for Katapult nodes...") payload = bytes([CANBUS_CMD_QUERY_UNASSIGNED]) self.admin_node.write(payload) curtime = self._loop.time() endtime = curtime + 2. self.uuids: List[int] = [] while curtime < endtime: timeout = max(.1, endtime - curtime) try: resp = await self.admin_node.read(8, timeout) except asyncio.TimeoutError: continue finally: curtime = self._loop.time() if len(resp) < 7 or resp[0] != CANBUS_RESP_NEED_NODEID: continue app_names = { KLIPPER_SET_NODE_CMD: "Klipper", CANBUS_CMD_SET_NODEID: "Katapult" } app = "Unknown" if len(resp) > 7: app = app_names.get(resp[7], "Unknown") data = resp[1:7] output_line(f"Detected UUID: {data.hex()}, Application: {app}") uuid = sum([v << ((5 - i) * 8) for i, v in enumerate(data)]) if uuid not in self.uuids and app == "Katapult": self.uuids.append(uuid) return self.uuids def _reset_nodes(self) -> None: output_line("Resetting all bootloader node IDs...") payload = bytes([CANBUS_CMD_CLEAR_NODE_ID]) self.admin_node.write(payload) def _set_node_id(self, uuid: int) -> CanNode: # Convert ID to a list plist = [(uuid >> ((5 - i) * 8)) & 0xFF for i in range(6)] plist.insert(0, CANBUS_CMD_SET_NODEID) node_id = len(self.nodes) + CANBUS_NODEID_OFFSET plist.append(node_id) payload = bytes(plist) self.admin_node.write(payload) decoded_id = node_id * 2 + 0x100 node = CanNode(decoded_id, self) self.nodes[decoded_id + 1] = node return node async def run( self, intf: str, uuid: int, fw_path: pathlib.Path, req_only: bool ) -> None: if not req_only and not fw_path.is_file(): raise FlashError("Invalid firmware path '%s'" % (fw_path)) try: self.cansock.bind((intf,)) except Exception: raise FlashError("Unable to bind socket to can0") self.closed = False self.cansock.setblocking(False) self._loop.add_reader( self.cansock.fileno(), self._handle_can_response) self._jump_to_bootloader(uuid) if req_only: output_line("Bootloader request command sent") return await asyncio.sleep(.5) self._reset_nodes() await asyncio.sleep(1.0) node = self._set_node_id(uuid) flasher = CanFlasher(node, fw_path) await asyncio.sleep(.5) try: await flasher.connect_btl() await flasher.verify_canbus_uuid(uuid) await flasher.send_file() await flasher.verify_file() finally: # always attempt to send the complete command. If # there is an error it will exit the bootloader # unless comms were broken await flasher.finish() async def run_query(self, intf: str): try: self.cansock.bind((intf,)) except Exception: raise FlashError("Unable to bind socket to can0") self.closed = False self.cansock.setblocking(False) self._loop.add_reader( self.cansock.fileno(), self._handle_can_response) self._reset_nodes() await asyncio.sleep(.5) await self._query_uuids() def close(self): if self.closed: return self.closed = True for node in self.nodes.values(): node.close() self._loop.remove_reader(self.cansock.fileno()) self.cansock.close() class SerialSocket: def __init__(self, loop: asyncio.AbstractEventLoop): self._loop = loop self.serial: Optional[Serial] = None self.node = CanNode(0, self) def _handle_response(self) -> None: assert self.serial is not None try: data = self.serial.read(4096) except SerialException: logging.exception("Error on serial read") self.close() else: self.node.feed_data(data) def send(self, can_id: int, payload: bytes = b"") -> None: assert self.serial is not None try: self.serial.write(payload) except SerialException: logging.exception("Error on serial write") self.close() async def _lookup_proc_name(self, process_id: str) -> str: has_sysctl = shutil.which("systemctl") is not None if has_sysctl: cmd = shlex.split(f"systemctl status {process_id}") proc = await asyncio.create_subprocess_exec( *cmd, stdout=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.PIPE ) stdout, _ = await proc.communicate() resp = stdout.strip().decode(errors="ignore") if resp: unit = resp.split(maxsplit=2) if len(unit) == 3: return f"Systemd Unit Name: {unit[1]}" cmdline_file = pathlib.Path(f"/proc/{process_id}/cmdline") if cmdline_file.exists(): res = cmdline_file.read_text().replace("\x00", " ").strip() return f"Command Line: {res}" exe_file = pathlib.Path(f"/proc/{process_id}/exe") if exe_file.exists(): return f"Executable: {exe_file.resolve()})" return "Name Unknown" async def validate_device(self, dev_strpath: str) -> None: dev_path = pathlib.Path(dev_strpath) if not dev_path.exists(): raise FlashError(f"No Serial Device found at {dev_path}") try: dev_st = dev_path.stat() except PermissionError as e: raise FlashError(f"No permission to access device {dev_path}") from e dev_id = (dev_st.st_dev, dev_st.st_ino) for fd_dir in pathlib.Path("/proc").glob("*/fd"): pid = fd_dir.parent.name if not pid.isdigit(): continue try: for item in fd_dir.iterdir(): item_st = item.stat() item_id = (item_st.st_dev, item_st.st_ino) if item_id == dev_id: proc_name = await self._lookup_proc_name(pid) output_line( f"Serial device {dev_path} in use by another program.\n" f"Process ID: {pid}\n" f"Process {proc_name}" ) raise FlashError(f"Serial device {dev_path} in use") except PermissionError: continue async def run(self, intf: str, baud: int, fw_path: pathlib.Path) -> None: if not fw_path.is_file(): raise FlashError("Invalid firmware path '%s'" % (fw_path)) await self.validate_device(intf) try: serial_dev = Serial( # type: ignore baudrate=baud, timeout=0, exclusive=True ) serial_dev.port = intf serial_dev.open() except (OSError, IOError, SerialException) as e: raise FlashError("Unable to open serial port: %s" % (e,)) self.serial = serial_dev self._loop.add_reader(self.serial.fileno(), self._handle_response) flasher = CanFlasher(self.node, fw_path) try: await flasher.connect_btl() await flasher.send_file() await flasher.verify_file() finally: # always attempt to send the complete command. If # there is an error it will exit the bootloader # unless comms were broken await flasher.finish() def close(self): if self.serial is None: return self._loop.remove_reader(self.serial.fileno()) self.serial.close() self.serial = None def main(): parser = argparse.ArgumentParser( description="Katapult Flash Tool") parser.add_argument( "-d", "--device", metavar='', help="Serial Device" ) parser.add_argument( "-b", "--baud", default=250000, metavar='', help="Serial baud rate" ) parser.add_argument( "-i", "--interface", default="can0", metavar='', help="Can Interface" ) parser.add_argument( "-f", "--firmware", metavar="", default="~/klipper/out/klipper.bin", help="Path to Klipper firmware file") parser.add_argument( "-u", "--uuid", metavar="", default=None, help="Can device uuid" ) parser.add_argument( "-q", "--query", action="store_true", help="Query Bootloader Device IDs" ) parser.add_argument( "-v", "--verbose", action="store_true", help="Enable verbose responses" ) parser.add_argument( "-r", "--request-bootloader", action="store_true", help="Requests the bootloader and exits (CAN only)" ) args = parser.parse_args() if not args.verbose: logging.getLogger().setLevel(logging.ERROR) intf = args.interface fpath = pathlib.Path(args.firmware).expanduser().resolve() loop = asyncio.get_event_loop() iscan = args.device is None req_only = args.request_bootloader sock = None try: if iscan: sock = CanSocket(loop) if args.query: loop.run_until_complete(sock.run_query(intf)) else: if args.uuid is None: raise FlashError( "The 'uuid' option must be specified to flash a device" ) output_line(f"Flashing CAN UUID {args.uuid} on interface {intf}") uuid = int(args.uuid, 16) loop.run_until_complete(sock.run(intf, uuid, fpath, req_only)) else: if not HAS_SERIAL: ser_inst_cmd = "pip3 install serial" if shutil.which("apt") is not None: ser_inst_cmd = "sudo apt install python3-serial" raise FlashError( "The pyserial python package was not found. To install " "run the following command in a terminal: \n\n" f" {ser_inst_cmd}\n\n" ) if args.device is None: raise FlashError( "The 'device' option must be specified to flash a device" ) output_line(f"Flashing Serial Device {args.device}, baud {args.baud}") sock = SerialSocket(loop) loop.run_until_complete(sock.run(args.device, args.baud, fpath)) except Exception: logging.exception("Flash Error") sys.exit(-1) finally: if sock is not None: sock.close() if args.query: output_line("Query Complete") else: output_line("Flash Success") if __name__ == '__main__': main()