mirror of
https://github.com/andreili/katapult.git
synced 2025-08-23 11:24:06 +02:00
flashtool: prime usb connections with double buffering
Signed-off-by: Eric Callahan <arksine.code@gmail.com>
This commit is contained in:
parent
542b6a8519
commit
2005ca5b13
@ -96,13 +96,17 @@ def get_usb_info(usb_path: pathlib.Path) -> Dict[str, Any]:
|
||||
id_path = usb_path.joinpath("idVendor")
|
||||
prod_path = usb_path.joinpath("idProduct")
|
||||
mfr_path = usb_path.joinpath("manufacturer")
|
||||
prod_path = usb_path.joinpath("product")
|
||||
if id_path.is_file() and prod_path.is_file():
|
||||
vid = id_path.read_text().strip().lower()
|
||||
pid = prod_path.read_text().strip().lower()
|
||||
usb_info["usb_id"] = f"{vid}:{pid}"
|
||||
usb_info["manufacturer"] = "unknown"
|
||||
usb_info["product"] = "unknown"
|
||||
if mfr_path.is_file():
|
||||
usb_info["manufacturer"] = mfr_path.read_text().strip().lower()
|
||||
if prod_path.is_file():
|
||||
usb_info["product"] = prod_path.read_text().strip().lower()
|
||||
return usb_info
|
||||
|
||||
def get_usb_path(device: pathlib.Path) -> Optional[pathlib.Path]:
|
||||
@ -141,6 +145,7 @@ class CanFlasher:
|
||||
self.node = node
|
||||
self.firmware_path = fw_file
|
||||
self.fw_sha = hashlib.sha1()
|
||||
self.primed = False
|
||||
self.file_size = 0
|
||||
self.block_size = 64
|
||||
self.block_count = 0
|
||||
@ -173,6 +178,26 @@ class CanFlasher:
|
||||
f"Detected Klipper binary version {ver}, MCU: {bin_mcu}"
|
||||
)
|
||||
|
||||
def _build_command(self, cmd: int, payload: bytes) -> bytearray:
|
||||
word_cnt = (len(payload) // 4) & 0xFF
|
||||
out_cmd = bytearray(CMD_HEADER)
|
||||
out_cmd.append(cmd)
|
||||
out_cmd.append(word_cnt)
|
||||
if payload:
|
||||
out_cmd.extend(payload)
|
||||
crc = crc16_ccitt(out_cmd[2:])
|
||||
out_cmd.extend(struct.pack("<H", crc))
|
||||
out_cmd.extend(CMD_TRAILER)
|
||||
return out_cmd
|
||||
|
||||
def prime(self) -> None:
|
||||
# Prime with an invalid command. This will generate an error
|
||||
# and force double buffered USB devices to respond after the
|
||||
# first command is sent.
|
||||
msg = self._build_command(0x90, b"")
|
||||
self.node.write(msg)
|
||||
self.primed = True
|
||||
|
||||
async def connect_btl(self) -> None:
|
||||
output_line("Attempting to connect to bootloader")
|
||||
ret = await self.send_command('CONNECT')
|
||||
@ -224,16 +249,8 @@ class CanFlasher:
|
||||
payload: bytes = b"",
|
||||
tries: int = 5
|
||||
) -> bytearray:
|
||||
word_cnt = (len(payload) // 4) & 0xFF
|
||||
cmd = BOOTLOADER_CMDS[cmdname]
|
||||
out_cmd = bytearray(CMD_HEADER)
|
||||
out_cmd.append(cmd)
|
||||
out_cmd.append(word_cnt)
|
||||
if payload:
|
||||
out_cmd.extend(payload)
|
||||
crc = crc16_ccitt(out_cmd[2:])
|
||||
out_cmd.extend(struct.pack("<H", crc))
|
||||
out_cmd.extend(CMD_TRAILER)
|
||||
out_cmd = self._build_command(cmd, payload)
|
||||
last_err = Exception()
|
||||
while tries:
|
||||
data = bytearray()
|
||||
@ -242,7 +259,7 @@ class CanFlasher:
|
||||
self.node.write(out_cmd)
|
||||
read_done = False
|
||||
while not read_done:
|
||||
ret = await self.node.readuntil()
|
||||
ret = await self.node.readuntil(CMD_TRAILER)
|
||||
data.extend(ret)
|
||||
while len(data) > 7:
|
||||
if data[:2] != CMD_HEADER:
|
||||
@ -251,6 +268,11 @@ class CanFlasher:
|
||||
recd_len = data[3] * 4
|
||||
read_done = len(data) == recd_len + 8
|
||||
break
|
||||
if self.primed and read_done:
|
||||
recd_len = 0
|
||||
data.clear()
|
||||
self.primed = False
|
||||
read_done = False
|
||||
except asyncio.CancelledError:
|
||||
raise
|
||||
except asyncio.TimeoutError:
|
||||
@ -750,6 +772,12 @@ class SerialSocket:
|
||||
raise FlashError("Unable to open serial port: %s" % (e,))
|
||||
return serial_dev
|
||||
|
||||
def _has_double_buffering(self, product: str) -> bool:
|
||||
if product.startswith("stm32"):
|
||||
variant = product[5:7]
|
||||
return variant not in ("f2", "f4", "h7")
|
||||
return False
|
||||
|
||||
async def run(
|
||||
self, intf: str, baud: int, fw_path: pathlib.Path, req_only: bool
|
||||
) -> None:
|
||||
@ -763,6 +791,7 @@ class SerialSocket:
|
||||
dev_info = get_usb_info(usb_dev_path)
|
||||
usb_id = dev_info.get("usb_id")
|
||||
usb_mfr = dev_info.get("manufacturer")
|
||||
usb_prod: str = dev_info.get("product", "unknown")
|
||||
if usb_mfr == "klipper" or usb_id == KLIPPER_USB_ID:
|
||||
# Request usb bootloader, wait for katapult
|
||||
output_line("Detected USB device running Klipper")
|
||||
@ -778,10 +807,17 @@ class SerialSocket:
|
||||
# Request serial bootloader and exit
|
||||
await self._request_serial_bootloader(intf, baud)
|
||||
return
|
||||
else:
|
||||
usb_prod = ""
|
||||
self.serial = self._open_device(intf, baud)
|
||||
self._loop.add_reader(self.serial.fileno(), self._handle_response)
|
||||
flasher = CanFlasher(self.node, fw_path)
|
||||
try:
|
||||
if self._has_double_buffering(usb_prod):
|
||||
# Prime the USB Connection with a dummy command. This is
|
||||
# necessary to get STM32 devices with usbfs double buffering
|
||||
# to respond immediately to the connect command.
|
||||
flasher.prime()
|
||||
await flasher.connect_btl()
|
||||
await flasher.send_file()
|
||||
await flasher.verify_file()
|
||||
|
Loading…
x
Reference in New Issue
Block a user