flashtool: prime usb connections with double buffering

Signed-off-by:  Eric Callahan <arksine.code@gmail.com>
This commit is contained in:
Eric Callahan 2024-09-25 10:57:15 -04:00
parent 542b6a8519
commit 2005ca5b13

View File

@ -96,13 +96,17 @@ def get_usb_info(usb_path: pathlib.Path) -> Dict[str, Any]:
id_path = usb_path.joinpath("idVendor") id_path = usb_path.joinpath("idVendor")
prod_path = usb_path.joinpath("idProduct") prod_path = usb_path.joinpath("idProduct")
mfr_path = usb_path.joinpath("manufacturer") mfr_path = usb_path.joinpath("manufacturer")
prod_path = usb_path.joinpath("product")
if id_path.is_file() and prod_path.is_file(): if id_path.is_file() and prod_path.is_file():
vid = id_path.read_text().strip().lower() vid = id_path.read_text().strip().lower()
pid = prod_path.read_text().strip().lower() pid = prod_path.read_text().strip().lower()
usb_info["usb_id"] = f"{vid}:{pid}" usb_info["usb_id"] = f"{vid}:{pid}"
usb_info["manufacturer"] = "unknown" usb_info["manufacturer"] = "unknown"
usb_info["product"] = "unknown"
if mfr_path.is_file(): if mfr_path.is_file():
usb_info["manufacturer"] = mfr_path.read_text().strip().lower() usb_info["manufacturer"] = mfr_path.read_text().strip().lower()
if prod_path.is_file():
usb_info["product"] = prod_path.read_text().strip().lower()
return usb_info return usb_info
def get_usb_path(device: pathlib.Path) -> Optional[pathlib.Path]: def get_usb_path(device: pathlib.Path) -> Optional[pathlib.Path]:
@ -141,6 +145,7 @@ class CanFlasher:
self.node = node self.node = node
self.firmware_path = fw_file self.firmware_path = fw_file
self.fw_sha = hashlib.sha1() self.fw_sha = hashlib.sha1()
self.primed = False
self.file_size = 0 self.file_size = 0
self.block_size = 64 self.block_size = 64
self.block_count = 0 self.block_count = 0
@ -173,6 +178,26 @@ class CanFlasher:
f"Detected Klipper binary version {ver}, MCU: {bin_mcu}" f"Detected Klipper binary version {ver}, MCU: {bin_mcu}"
) )
def _build_command(self, cmd: int, payload: bytes) -> bytearray:
word_cnt = (len(payload) // 4) & 0xFF
out_cmd = bytearray(CMD_HEADER)
out_cmd.append(cmd)
out_cmd.append(word_cnt)
if payload:
out_cmd.extend(payload)
crc = crc16_ccitt(out_cmd[2:])
out_cmd.extend(struct.pack("<H", crc))
out_cmd.extend(CMD_TRAILER)
return out_cmd
def prime(self) -> None:
# Prime with an invalid command. This will generate an error
# and force double buffered USB devices to respond after the
# first command is sent.
msg = self._build_command(0x90, b"")
self.node.write(msg)
self.primed = True
async def connect_btl(self) -> None: async def connect_btl(self) -> None:
output_line("Attempting to connect to bootloader") output_line("Attempting to connect to bootloader")
ret = await self.send_command('CONNECT') ret = await self.send_command('CONNECT')
@ -224,16 +249,8 @@ class CanFlasher:
payload: bytes = b"", payload: bytes = b"",
tries: int = 5 tries: int = 5
) -> bytearray: ) -> bytearray:
word_cnt = (len(payload) // 4) & 0xFF
cmd = BOOTLOADER_CMDS[cmdname] cmd = BOOTLOADER_CMDS[cmdname]
out_cmd = bytearray(CMD_HEADER) out_cmd = self._build_command(cmd, payload)
out_cmd.append(cmd)
out_cmd.append(word_cnt)
if payload:
out_cmd.extend(payload)
crc = crc16_ccitt(out_cmd[2:])
out_cmd.extend(struct.pack("<H", crc))
out_cmd.extend(CMD_TRAILER)
last_err = Exception() last_err = Exception()
while tries: while tries:
data = bytearray() data = bytearray()
@ -242,7 +259,7 @@ class CanFlasher:
self.node.write(out_cmd) self.node.write(out_cmd)
read_done = False read_done = False
while not read_done: while not read_done:
ret = await self.node.readuntil() ret = await self.node.readuntil(CMD_TRAILER)
data.extend(ret) data.extend(ret)
while len(data) > 7: while len(data) > 7:
if data[:2] != CMD_HEADER: if data[:2] != CMD_HEADER:
@ -251,6 +268,11 @@ class CanFlasher:
recd_len = data[3] * 4 recd_len = data[3] * 4
read_done = len(data) == recd_len + 8 read_done = len(data) == recd_len + 8
break break
if self.primed and read_done:
recd_len = 0
data.clear()
self.primed = False
read_done = False
except asyncio.CancelledError: except asyncio.CancelledError:
raise raise
except asyncio.TimeoutError: except asyncio.TimeoutError:
@ -750,6 +772,12 @@ class SerialSocket:
raise FlashError("Unable to open serial port: %s" % (e,)) raise FlashError("Unable to open serial port: %s" % (e,))
return serial_dev return serial_dev
def _has_double_buffering(self, product: str) -> bool:
if product.startswith("stm32"):
variant = product[5:7]
return variant not in ("f2", "f4", "h7")
return False
async def run( async def run(
self, intf: str, baud: int, fw_path: pathlib.Path, req_only: bool self, intf: str, baud: int, fw_path: pathlib.Path, req_only: bool
) -> None: ) -> None:
@ -763,6 +791,7 @@ class SerialSocket:
dev_info = get_usb_info(usb_dev_path) dev_info = get_usb_info(usb_dev_path)
usb_id = dev_info.get("usb_id") usb_id = dev_info.get("usb_id")
usb_mfr = dev_info.get("manufacturer") usb_mfr = dev_info.get("manufacturer")
usb_prod: str = dev_info.get("product", "unknown")
if usb_mfr == "klipper" or usb_id == KLIPPER_USB_ID: if usb_mfr == "klipper" or usb_id == KLIPPER_USB_ID:
# Request usb bootloader, wait for katapult # Request usb bootloader, wait for katapult
output_line("Detected USB device running Klipper") output_line("Detected USB device running Klipper")
@ -778,10 +807,17 @@ class SerialSocket:
# Request serial bootloader and exit # Request serial bootloader and exit
await self._request_serial_bootloader(intf, baud) await self._request_serial_bootloader(intf, baud)
return return
else:
usb_prod = ""
self.serial = self._open_device(intf, baud) self.serial = self._open_device(intf, baud)
self._loop.add_reader(self.serial.fileno(), self._handle_response) self._loop.add_reader(self.serial.fileno(), self._handle_response)
flasher = CanFlasher(self.node, fw_path) flasher = CanFlasher(self.node, fw_path)
try: try:
if self._has_double_buffering(usb_prod):
# Prime the USB Connection with a dummy command. This is
# necessary to get STM32 devices with usbfs double buffering
# to respond immediately to the connect command.
flasher.prime()
await flasher.connect_btl() await flasher.connect_btl()
await flasher.send_file() await flasher.send_file()
await flasher.verify_file() await flasher.verify_file()