// Copyright 2019 Robin Krahl // SPDX-License-Identifier: GPL-3.0-or-later use usb_device::bus::{InterfaceNumber, StringIndex, UsbBus, UsbBusAllocator}; use usb_device::class::{ControlIn, ControlOut, UsbClass}; use usb_device::control::{Recipient, Request, RequestType}; use usb_device::descriptor::DescriptorWriter; use usb_device::endpoint::{EndpointAddress, EndpointIn}; use usb_device::{Result, UsbError}; const SPECIFICATION_RELEASE: u16 = 0x111; const INTERFACE_CLASS_HID: u8 = 0x03; enum_u8! { #[derive(Clone, Copy, Debug, PartialEq)] pub enum Subclass { None = 0x00, BootInterface = 0x01, } } enum_u8! { #[derive(Clone, Copy, Debug, PartialEq)] pub enum Protocol { None = 0x00, Keyboard = 0x01, Mouse = 0x02, } } enum_u8! { #[derive(Debug, Clone, Copy)] pub enum DescriptorType { Hid = 0x21, Report = 0x22, Physical = 0x23, } } pub trait HidDevice { fn subclass(&self) -> Subclass; fn protocol(&self) -> Protocol; fn report_descriptor(&self) -> &[u8]; } pub struct HidClass<'a, B: UsbBus, D: HidDevice> { device: D, interface: InterfaceNumber, endpoint_interrupt_in: EndpointIn<'a, B>, expect_interrupt_in_complete: bool, } impl HidClass<'_, B, D> { pub fn new(device: D, alloc: &UsbBusAllocator) -> HidClass<'_, B, D> { HidClass { device, interface: alloc.interface(), endpoint_interrupt_in: alloc.interrupt(8, 10), expect_interrupt_in_complete: false, } } fn get_report_descriptor(&self, index: u8, buf: &mut [u8]) -> Result { if index == 0 { let report_descriptor = self.device.report_descriptor(); let len = report_descriptor.len(); if len > buf.len() { Err(UsbError::BufferOverflow) } else { buf[0..len].copy_from_slice(report_descriptor); Ok(len) } } else { Err(UsbError::InvalidState) } } } impl UsbClass for HidClass<'_, B, D> { fn poll(&mut self) {} fn reset(&mut self) { self.expect_interrupt_in_complete = false; } fn get_configuration_descriptors(&self, writer: &mut DescriptorWriter) -> Result<()> { writer.interface( self.interface, INTERFACE_CLASS_HID, self.device.subclass().into(), self.device.protocol().into(), )?; let report_descriptor = self.device.report_descriptor(); let descriptor_len = report_descriptor.len(); if descriptor_len > u16::max_value() as usize { return Err(UsbError::InvalidState); } let descriptor_len = (descriptor_len as u16).to_le_bytes(); let specification_release = SPECIFICATION_RELEASE.to_le_bytes(); writer.write( DescriptorType::Hid.into(), &[ specification_release[0], // bcdHID.lower specification_release[1], // bcdHID.upper 0, // bCountryCode: 0 = not supported 1, // bNumDescriptors DescriptorType::Report.into(), // bDescriptorType descriptor_len[0], // bDescriptorLength.lower descriptor_len[1], // bDescriptorLength.upper ], )?; writer.endpoint(&self.endpoint_interrupt_in)?; Ok(()) } fn get_string(&self, _index: StringIndex, _lang_id: u16) -> Option<&str> { None } fn endpoint_in_complete(&mut self, addr: EndpointAddress) { if addr == self.endpoint_interrupt_in.address() { if self.expect_interrupt_in_complete { self.expect_interrupt_in_complete = false; } else { panic!("unexpected endpoint_in_complete"); } } } fn endpoint_out(&mut self, _addr: EndpointAddress) {} fn control_in(&mut self, xfer: ControlIn) { let req = xfer.request(); if req.request_type == RequestType::Standard && req.recipient == Recipient::Interface && req.request == Request::GET_DESCRIPTOR { let (dtype, index) = req.descriptor_type_index(); if dtype == DescriptorType::Report.into() { xfer.accept(|mut buf| self.get_report_descriptor(index, &mut buf)) .ok(); } } } fn control_out(&mut self, _xfer: ControlOut) {} }