From 9d894c71906b9f0e0f62db33089b722b1ae382b5 Mon Sep 17 00:00:00 2001 From: Eric Callahan Date: Wed, 13 Apr 2022 15:20:22 -0400 Subject: [PATCH] flash_can: bring module in line with latest mcu code Signed-off-by: Eric Callahan --- scripts/flash_can.py | 363 ++++++++++++++++++++++++++++++++++--------- 1 file changed, 288 insertions(+), 75 deletions(-) diff --git a/scripts/flash_can.py b/scripts/flash_can.py index 186f1c1..1c4f988 100644 --- a/scripts/flash_can.py +++ b/scripts/flash_can.py @@ -1,98 +1,117 @@ -#!/usr/bin/env python2 +#!/usr/bin/env python3 # Script to upload software via Can Bootloader # -# Copyright (C) 2021 Eric Callahan +# Copyright (C) 2022 Eric Callahan # # This file may be distributed under the terms of the GNU GPLv3 license. +from __future__ import annotations import sys import os -import time +import asyncio +import socket +import struct +import logging +import errno import argparse import hashlib -import traceback -import serial +import pathlib +from typing import Dict, List, Optional, Union -def output_line(msg): +def output_line(msg: str) -> None: sys.stdout.write(msg + "\n") sys.stdout.flush() -def output(msg): +def output(msg: str) -> None: sys.stdout.write(msg) sys.stdout.flush() +logging.basicConfig(level=logging.DEBUG) +CAN_FMT = " None: + self.node = node self.fw_name = fw_file self.fw_sha = hashlib.sha1() self.file_size = 0 self.block_size = 64 self.block_count = 0 - def connect_btl(self): + async def connect_btl(self): output_line("Attempting to connect to bootloader") - self.block_size = self.send_command('CONNECT') + self.block_size = await self.send_command('CONNECT') if self.block_size not in [64, 128, 256, 512]: raise FlashCanError("Invalid Block Size: %d" % (self.block_size,)) output_line("Connected, Block Size: %d bytes" % (self.block_size,)) - def serial_read(self, count): - ret = b"" - tries = 10 - while tries: - data = self.ser.read(count) - count -= len(data) - ret += data - if not count: - break - tries -= 1 - return ret - - def send_command(self, cmdname, ack=ACK_CMD, arg=0, tries=5): + async def send_command( + self, + cmdname: str, + ack: bytearray = ACK_CMD, + arg: int = 0, + tries: int = 5 + ) -> int: cmd = BOOTLOADER_CMDS[cmdname] - out_cmd = bytearray(CMD_HEADER + cmd) + out_cmd = bytearray(CMD_HEADER) + out_cmd.append(cmd) out_cmd.append((arg >> 8) & 0xFF) out_cmd.append(arg & 0xFF) out_cmd.append(CMD_TRAILER) while tries: try: - self.ser.write(out_cmd) - data = bytearray(self.serial_read(8)) + ret = await self.node.write_with_response(out_cmd, 8) + data = bytearray(ret) except Exception: - traceback.print_exc() + logging.exception("Can Read Error") else: - if len(data) == 8 and data.startswith(ack) and \ - data.endswith(CMD_TRAILER): + if ( + len(data) == 8 and + data[:5] == ack and + data[-1] == CMD_TRAILER + ): return (data[5] << 8) | data[6] tries -= 1 - time.sleep(.1) + await asyncio.sleep(.1) raise FlashCanError("Error sending command [%s] to Can Device" % (cmdname)) - def send_file(self): + async def send_file(self): last_percent = 0 output_line("Flashing '%s'..." % (self.fw_name)) output("\n[") @@ -104,17 +123,17 @@ class FlashCan: buf = f.read(self.block_size) if not buf: break - self.send_command('SEND_BLOCK', arg=self.block_count) + await self.send_command('SEND_BLOCK', arg=self.block_count) if len(buf) < self.block_size: buf += b"\xFF" * (self.block_size - len(buf)) self.fw_sha.update(buf) - self.ser.write(buf) - ack = self.serial_read(8) + self.node.write(buf) + ack = await self.node.readexactly(8) expect = bytearray(ACK_BLOCK_RECD) expect.append((self.block_count >> 8) & 0xFF) expect.append(self.block_count & 0xFF) expect.append(CMD_TRAILER) - if ack != expect: + if ack != bytes(expect): output_line("\nExpected resp: %s, Recd: %s" % (expect, ack)) raise FlashCanError("Did not receive ACK for sent block %d" % (self.block_count)) @@ -124,10 +143,10 @@ class FlashCan: if pct >= last_percent + 2: last_percent += 2. output("#") - page_count = self.send_command('SEND_EOF') + page_count = await self.send_command('SEND_EOF') output_line("]\n\nWrite complete: %d pages" % (page_count)) - def verify_file(self): + async def verify_file(self): last_percent = 0 output_line("Verifying (block count = %d)..." % (self.block_count,)) output("\n[") @@ -135,15 +154,15 @@ class FlashCan: for i in range(self.block_count): tries = 3 while tries: - resp = self.send_command("REQUEST_BLOCK", arg=i) + resp = await self.send_command("REQUEST_BLOCK", arg=i) if resp == i: # command should ack with the requested block as # parameter - buf = self.serial_read(self.block_size) + buf = await self.node.readexactly(self.block_size, timeout=10.) if len(buf) == self.block_size: break tries -= 1 - time.sleep(.1) + await asyncio.sleep(.1) else: output_line("Error") raise FlashCanError("Block Request Error, block: %d" % (i,)) @@ -159,37 +178,231 @@ class FlashCan: % (fw_hex, ver_hex)) output_line("]\n\nVerification Complete: SHA = %s" % (ver_hex)) - def finish(self): - self.send_command("COMPLETE") + async def finish(self): + await self.send_command("COMPLETE") + + +class CanNode: + def __init__(self, node_id: int, cansocket: CanSocket) -> None: + self.node_id = node_id + self._reader = asyncio.StreamReader(CAN_READER_LIMIT) + self._cansocket = cansocket + + async def read( + self, n: int = -1, timeout: Optional[float] = 2 + ) -> bytes: + return await asyncio.wait_for(self._reader.read(n), timeout) + + async def readexactly( + self, n: int, timeout: Optional[float] = 2 + ) -> bytes: + return await asyncio.wait_for(self._reader.readexactly(n), timeout) + + def write(self, payload: Union[bytes, bytearray]) -> None: + if isinstance(payload, bytearray): + payload = bytes(payload) + self._cansocket.send(self.node_id, payload) + + async def write_with_response( + self, + payload: Union[bytearray, bytes], + resp_length: int, + timeout: Optional[float] = 2. + ) -> bytes: + self.write(payload) + return await self.readexactly(resp_length, timeout) + + def feed_data(self, data: bytes) -> None: + self._reader.feed_data(data) + + def close(self) -> None: + self._reader.feed_eof() + +class CanSocket: + def __init__(self, loop: asyncio.AbstractEventLoop): + self._loop = loop + self.cansock = socket.socket(socket.PF_CAN, socket.SOCK_RAW, + socket.CAN_RAW) + self.admin_node = CanNode(CANBUS_ID_ADMIN, self) + self.nodes: Dict[int, CanNode] = { + CANBUS_ID_ADMIN_RESP: self.admin_node + } + + self.input_buffer = b"" + self.output_packets: List[bytes] = [] + self.input_busy = False + self.output_busy = False + self.closed = True + + def _handle_can_response(self) -> None: + try: + data = self.cansock.recv(4096) + except socket.error as e: + # If bad file descriptor allow connection to be + # closed by the data check + if e.errno == errno.EBADF: + logging.exception("Can Socket Read Error, closing") + data = b'' + else: + return + if not data: + # socket closed + self.close() + return + self.input_buffer += data + if self.input_busy: + return + self.input_busy = True + while len(self.input_buffer) >= 16: + packet = self.input_buffer[:16] + self._process_packet(packet) + self.input_buffer = self.input_buffer[16:] + self.input_busy = False + + def _process_packet(self, packet: bytes) -> None: + can_id, length, data = struct.unpack(CAN_FMT, packet) + can_id &= socket.CAN_EFF_MASK + payload = data[:length] + node = self.nodes.get(can_id) + if node is not None: + node.feed_data(payload) + + def send(self, can_id: int, payload: bytes = b"") -> None: + if can_id > 0x7FF: + can_id |= socket.CAN_EFF_FLAG + if not payload: + packet = struct.pack(CAN_FMT, can_id, 0, b"") + self.output_packets.append(packet) + else: + while payload: + length = min(len(payload), 8) + pkt_data = payload[:length] + payload = payload[length:] + packet = struct.pack( + CAN_FMT, can_id, length, pkt_data) + self.output_packets.append(packet) + if self.output_busy: + return + self.output_busy = True + asyncio.create_task(self._do_can_send()) + + async def _do_can_send(self): + while self.output_packets: + packet = self.output_packets.pop(0) + try: + await self._loop.sock_sendall(self.cansock, packet) + except socket.error: + logging.info("Socket Write Error, closing") + self.close() + break + self.output_busy = False + + def _jump_to_bootloader(self, uuid: int): + # TODO: Send Klipper Admin command to jump to bootloader. + # It will need to be implemented + output_line("Sending bootloader jump command...") + plist = [(uuid >> ((5 - i) * 8)) & 0xFF for i in range(6)] + plist.insert(0, KLIPPER_REBOOT_CMD) + self.send(KLIPPER_ADMIN_ID, bytes(plist)) + + async def _query_uuids(self) -> List[int]: + output_line("Checking for canboot nodes...") + payload = bytes([CANBUS_CMD_QUERY_UNASSIGNED]) + self.admin_node.write(payload) + curtime = self._loop.time() + endtime = curtime + 2. + self.uuids: List[int] = [] + while curtime < endtime: + diff = endtime - curtime + try: + resp = await self.admin_node.readexactly(7, diff) + except asyncio.TimeoutError: + break + finally: + curtime = self._loop.time() + if resp[0] != CANBUS_RESP_NEED_NODEID: + continue + output_line(f"Detected UUID: {resp[1:].hex()}") + uuid = sum([v << ((5 - i) * 8) for i, v in enumerate(resp[1:7])]) + if uuid not in self.uuids: + self.uuids.append(uuid) + return self.uuids + + def _reset_nodes(self) -> None: + output_line("Resetting all bootloader node IDs...") + payload = bytes([CANBUS_CMD_CLEAR_NODE_ID]) + self.admin_node.write(payload) + + def _set_node_id(self, uuid: int) -> CanNode: + # Convert ID to a list + plist = [(uuid >> ((5 - i) * 8)) & 0xFF for i in range(6)] + plist.insert(0, CANBUS_CMD_SET_NODEID) + node_id = len(self.nodes) + plist.append(node_id) + payload = bytes(plist) + self.admin_node.write(payload) + decoded_id = node_id * 2 + CANBUS_NODE_OFFSET + node = CanNode(decoded_id, self) + self.nodes[decoded_id + 1] = node + return node + + async def run(self, uuid: int, fw_path: pathlib.Path): + if not fw_path.is_file(): + raise FlashCanError("Invalid firmware path '%s'" % (fw_path)) + try: + self.cansock.bind(("can0",)) + except Exception: + raise FlashCanError("Unable to bind socket to can0") + self.closed = False + self.cansock.setblocking(False) + self._loop.add_reader( + self.cansock.fileno(), self._handle_can_response) + self._jump_to_bootloader(uuid) + await asyncio.sleep(.5) + self._reset_nodes() + await asyncio.sleep(.5) + id_list = await self._query_uuids() + if uuid not in id_list: + raise FlashCanError( + f"Unable to find node matching UUID: {uuid}" + ) + node = self._set_node_id(uuid) + flasher = CanFlasher(node, fw_path) + await asyncio.sleep(.5) + await flasher.connect_btl() + await flasher.send_file() + await flasher.verify_file() + await flasher.finish() def close(self): - self.ser.close() + if self.closed: + return + self.closed = True + for node in self.nodes.values(): + node.close() + self._loop.remove_reader(self.cansock.fileno()) + self.cansock.close() def main(): parser = argparse.ArgumentParser( description="Can Bootloader Flash Utility") - parser.add_argument('device', metavar="", - help="Can device location") + parser.add_argument('uuid', metavar="", + help="Can device uuid") parser.add_argument('fwpath', metavar="", help="Path to firmware file") args = parser.parse_args() - fcan = None - dev = args.device - fname = os.path.abspath(os.path.expanduser(args.fwpath)) + uuid = int(args.uuid, 16) + fpath = pathlib.Path(args.fwpath).expanduser().resolve() + loop = asyncio.get_event_loop() try: - fcan = FlashCan(dev, fname) - time.sleep(1.) - fcan.connect_btl() - fcan.send_file() - fcan.verify_file() - fcan.finish() + cansock = CanSocket(loop) + loop.run_until_complete(cansock.run(uuid, fpath)) except Exception as e: - output("Can Flash Error: ") - output_line(str(e)) + logging.exception("Can Flash Error") sys.exit(-1) finally: - if fcan is not None: - fcan.close() + if cansock is not None: + cansock.close() output_line("CAN Flash Success")