flashtool: support usb/serial bootloader requests

Signed-off-by:  Eric Callahan <arksine.code@gmail.com>
This commit is contained in:
Eric Callahan 2024-09-20 19:18:06 -04:00
parent dd8b0a0c9a
commit 542b6a8519

View File

@ -7,6 +7,8 @@
from __future__ import annotations from __future__ import annotations
import sys import sys
import os import os
import termios
import fcntl
import zlib import zlib
import json import json
import asyncio import asyncio
@ -81,9 +83,55 @@ CANBUS_CMD_CLEAR_NODE_ID = 0x12
CANBUS_RESP_NEED_NODEID = 0x20 CANBUS_RESP_NEED_NODEID = 0x20
CANBUS_NODEID_OFFSET = 128 CANBUS_NODEID_OFFSET = 128
# USB IDs
KATAPULT_USB_ID = "1d50:6177"
KLIPPER_USB_ID = "1d50:614e"
SERIAL_BL_REQ = b"~ \x1c Request Serial Bootloader!! ~"
class FlashError(Exception): class FlashError(Exception):
pass pass
def get_usb_info(usb_path: pathlib.Path) -> Dict[str, Any]:
usb_info: Dict[str, Any] = {}
id_path = usb_path.joinpath("idVendor")
prod_path = usb_path.joinpath("idProduct")
mfr_path = usb_path.joinpath("manufacturer")
if id_path.is_file() and prod_path.is_file():
vid = id_path.read_text().strip().lower()
pid = prod_path.read_text().strip().lower()
usb_info["usb_id"] = f"{vid}:{pid}"
usb_info["manufacturer"] = "unknown"
if mfr_path.is_file():
usb_info["manufacturer"] = mfr_path.read_text().strip().lower()
return usb_info
def get_usb_path(device: pathlib.Path) -> Optional[pathlib.Path]:
device_path = device.resolve()
if not device_path.exists():
return None
sys_dev_path = pathlib.Path("/sys/class/tty").joinpath(device_path.name)
if not sys_dev_path.exists():
return None
sys_dev_path = sys_dev_path.resolve()
for usb_path in sys_dev_path.parents:
dnum_file = usb_path.joinpath("devnum")
bnum_file = usb_path.joinpath("busnum")
if dnum_file.is_file() and bnum_file.is_file():
return usb_path
return None
def get_stable_usb_symlink(device: pathlib.Path) -> pathlib.Path:
device_path = device.resolve()
ser_by_path_dir = pathlib.Path("/dev/serial/by-path")
try:
if ser_by_path_dir.exists():
for item in ser_by_path_dir.iterdir():
if item.samefile(device_path):
return item
except OSError:
pass
return device_path
class CanFlasher: class CanFlasher:
def __init__( def __init__(
self, self,
@ -461,8 +509,6 @@ class CanSocket:
self.output_busy = False self.output_busy = False
def _jump_to_bootloader(self, uuid: int): def _jump_to_bootloader(self, uuid: int):
# TODO: Send Klipper Admin command to jump to bootloader.
# It will need to be implemented
output_line("Sending bootloader jump command...") output_line("Sending bootloader jump command...")
plist = [(uuid >> ((5 - i) * 8)) & 0xFF for i in range(6)] plist = [(uuid >> ((5 - i) * 8)) & 0xFF for i in range(6)]
plist.insert(0, KLIPPER_REBOOT_CMD) plist.insert(0, KLIPPER_REBOOT_CMD)
@ -650,19 +696,89 @@ class SerialSocket:
) )
raise FlashError(f"Serial device {dev_path} in use") raise FlashError(f"Serial device {dev_path} in use")
async def run(self, intf: str, baud: int, fw_path: pathlib.Path) -> None: async def _request_usb_bootloader(self, device: pathlib.Path) -> pathlib.Path:
if not fw_path.is_file(): output_line(f"Requesting USB bootloader for {device}...")
raise FlashError("Invalid firmware path '%s'" % (fw_path)) usb_dev_path = get_usb_path(device)
await self.validate_device(intf) if usb_dev_path is None:
output_line(f"Device path {device} is not a usb device")
return device
stable_path = get_stable_usb_symlink(device)
fd: Optional[int] = None
with contextlib.suppress(OSError):
fd = os.open(str(device), os.O_RDWR)
fcntl.ioctl(
fd, termios.TIOCMBIS, struct.pack('I', termios.TIOCM_DTR)
)
t = termios.tcgetattr(fd)
t[4] = t[5] = termios.B1200
termios.tcsetattr(fd, termios.TCSANOW, t)
fcntl.ioctl(
fd, termios.TIOCMBIC, struct.pack('I', termios.TIOCM_DTR)
)
if fd is not None:
os.close(fd)
output("Waiting for USB Reconnect.")
for _ in range(8):
await asyncio.sleep(.5)
output(".")
usb_info = get_usb_info(usb_dev_path)
mfr = usb_info.get("manufacturer")
if mfr == "katapult":
output_line("Katapult detected")
await asyncio.sleep(1.0)
break
else:
output_line("timed out")
return stable_path
async def _request_serial_bootloader(self, device: str, baud: int) -> None:
output_line(f"Requesting serial bootloader for device {device}...")
self.serial = self._open_device(device, baud)
self.send(0, SERIAL_BL_REQ)
await asyncio.sleep(1.)
if self.serial is not None:
self.close()
def _open_device(self, device: str, baud: int) -> Serial:
try: try:
serial_dev = Serial( # type: ignore serial_dev = Serial( # type: ignore
baudrate=baud, timeout=0, exclusive=True baudrate=baud, timeout=0, exclusive=True
) )
serial_dev.port = intf serial_dev.port = device
serial_dev.open() serial_dev.open()
except (OSError, IOError, SerialException) as e: except (OSError, IOError, SerialException) as e:
raise FlashError("Unable to open serial port: %s" % (e,)) raise FlashError("Unable to open serial port: %s" % (e,))
self.serial = serial_dev return serial_dev
async def run(
self, intf: str, baud: int, fw_path: pathlib.Path, req_only: bool
) -> None:
if not fw_path.is_file():
raise FlashError("Invalid firmware path '%s'" % (fw_path))
await self.validate_device(intf)
intf_path = pathlib.Path(intf)
usb_dev_path = get_usb_path(intf_path)
dev_info: Dict[str, Any] = {}
if usb_dev_path is not None:
dev_info = get_usb_info(usb_dev_path)
usb_id = dev_info.get("usb_id")
usb_mfr = dev_info.get("manufacturer")
if usb_mfr == "klipper" or usb_id == KLIPPER_USB_ID:
# Request usb bootloader, wait for katapult
output_line("Detected USB device running Klipper")
new_intf = await self._request_usb_bootloader(intf_path)
intf = str(new_intf)
if req_only:
return
elif usb_mfr == "katapult" or usb_id == KATAPULT_USB_ID:
output_line("Detected USB device running Katapult")
if req_only:
return
elif req_only:
# Request serial bootloader and exit
await self._request_serial_bootloader(intf, baud)
return
self.serial = self._open_device(intf, baud)
self._loop.add_reader(self.serial.fileno(), self._handle_response) self._loop.add_reader(self.serial.fileno(), self._handle_response)
flasher = CanFlasher(self.node, fw_path) flasher = CanFlasher(self.node, fw_path)
try: try:
@ -720,10 +836,10 @@ async def main(args: argparse.Namespace) -> int:
) )
output_line(f"Flashing Serial Device {args.device}, baud {args.baud}") output_line(f"Flashing Serial Device {args.device}, baud {args.baud}")
sock = SerialSocket(loop) sock = SerialSocket(loop)
await sock.run(args.device, args.baud, fpath) await sock.run(args.device, args.baud, fpath, req_only)
except Exception: except Exception:
logging.exception("Flash Error") logging.exception("Flash Error")
return -1 return 1
finally: finally:
if sock is not None: if sock is not None:
sock.close() sock.close()
@ -767,7 +883,7 @@ if __name__ == '__main__':
) )
parser.add_argument( parser.add_argument(
"-r", "--request-bootloader", action="store_true", "-r", "--request-bootloader", action="store_true",
help="Requests the bootloader and exits (CAN only)" help="Requests the bootloader and exits"
) )
args = parser.parse_args() args = parser.parse_args()