// Copyright 2019 Robin Krahl // SPDX-License-Identifier: GPL-3.0-or-later use usb_device::bus::{InterfaceNumber, StringIndex, UsbBus, UsbBusAllocator}; use usb_device::class::{ControlIn, ControlOut, UsbClass}; use usb_device::control; use usb_device::control::{Recipient, RequestType}; use usb_device::descriptor::DescriptorWriter; use usb_device::endpoint::{EndpointAddress, EndpointIn}; use usb_device::UsbError; use crate::util::TryFrom; const SPECIFICATION_RELEASE: u16 = 0x111; const INTERFACE_CLASS_HID: u8 = 0x03; enum_u8! { #[derive(Clone, Copy, Debug, PartialEq)] pub enum Subclass { None = 0x00, BootInterface = 0x01, } } enum_u8! { #[derive(Clone, Copy, Debug, PartialEq)] pub enum Protocol { None = 0x00, Keyboard = 0x01, Mouse = 0x02, } } enum_u8! { #[derive(Debug, Clone, Copy)] pub enum DescriptorType { Hid = 0x21, Report = 0x22, Physical = 0x23, } } enum_u8! { #[derive(Debug, Clone, Copy, PartialEq)] pub enum Request { GetReport = 0x01, GetIdle = 0x02, GetProtocol = 0x03, SetReport = 0x09, SetIdle = 0x0a, SetProtocol = 0x0b, } } #[derive(Debug, Clone, Copy, PartialEq)] pub enum ReportType { Input, Output, Feature, Reserved(u8), } impl From for ReportType { fn from(val: u8) -> Self { match val { 1 => ReportType::Input, 2 => ReportType::Output, 3 => ReportType::Feature, _ => ReportType::Reserved(val), } } } pub trait HidDevice { fn subclass(&self) -> Subclass; fn protocol(&self) -> Protocol; fn report_descriptor(&self) -> &[u8]; fn set_report(&mut self, report_type: ReportType, report_id: u8, data: &[u8]) -> Result<(), ()>; fn get_report(&mut self, report_type: ReportType, report_id: u8) -> Result<&[u8], ()>; } pub struct HidClass<'a, B: UsbBus, D: HidDevice> { device: D, interface: InterfaceNumber, endpoint_interrupt_in: EndpointIn<'a, B>, expect_interrupt_in_complete: bool, } impl HidClass<'_, B, D> { pub fn new(device: D, alloc: &UsbBusAllocator) -> HidClass<'_, B, D> { HidClass { device, interface: alloc.interface(), endpoint_interrupt_in: alloc.interrupt(8, 10), expect_interrupt_in_complete: false, } } fn get_report(&mut self, xfer: ControlIn) { let req = xfer.request(); let [report_type, report_id] = req.value.to_be_bytes(); let report_type = ReportType::from(report_type); match self.device.get_report(report_type, report_id) { Ok(data) => xfer.accept_with(data).ok(), Err(()) => xfer.reject().ok(), }; } fn set_report(&mut self, xfer: ControlOut) { let req = xfer.request(); let [report_type, report_id] = req.value.to_be_bytes(); let report_type = ReportType::from(report_type); match self.device.set_report(report_type, report_id, xfer.data()) { Ok(()) => xfer.accept().ok(), Err(()) => xfer.reject().ok(), }; } } impl UsbClass for HidClass<'_, B, D> { fn poll(&mut self) {} fn reset(&mut self) { self.expect_interrupt_in_complete = false; } fn get_configuration_descriptors( &self, writer: &mut DescriptorWriter, ) -> usb_device::Result<()> { writer.interface( self.interface, INTERFACE_CLASS_HID, self.device.subclass().into(), self.device.protocol().into(), )?; let report_descriptor = self.device.report_descriptor(); let descriptor_len = report_descriptor.len(); if descriptor_len > u16::max_value() as usize { return Err(UsbError::InvalidState); } let descriptor_len = (descriptor_len as u16).to_le_bytes(); let specification_release = SPECIFICATION_RELEASE.to_le_bytes(); writer.write( DescriptorType::Hid.into(), &[ specification_release[0], // bcdHID.lower specification_release[1], // bcdHID.upper 0, // bCountryCode: 0 = not supported 1, // bNumDescriptors DescriptorType::Report.into(), // bDescriptorType descriptor_len[0], // bDescriptorLength.lower descriptor_len[1], // bDescriptorLength.upper ], )?; writer.endpoint(&self.endpoint_interrupt_in)?; Ok(()) } fn get_string(&self, _index: StringIndex, _lang_id: u16) -> Option<&str> { None } fn endpoint_in_complete(&mut self, addr: EndpointAddress) { if addr == self.endpoint_interrupt_in.address() { if self.expect_interrupt_in_complete { self.expect_interrupt_in_complete = false; } else { panic!("unexpected endpoint_in_complete"); } } } fn endpoint_out(&mut self, _addr: EndpointAddress) {} fn control_in(&mut self, xfer: ControlIn) { let req = xfer.request(); match (req.request_type, req.recipient) { (RequestType::Standard, Recipient::Interface) => { if req.request == control::Request::GET_DESCRIPTOR { let (dtype, index) = req.descriptor_type_index(); if dtype == DescriptorType::Report.into() && index == 0 { let descriptor = self.device.report_descriptor(); xfer.accept_with(descriptor).ok(); } } } (RequestType::Class, Recipient::Interface) => { if let Ok(request) = Request::try_from(req.request) { if request == Request::GetReport { self.get_report(xfer); } } } _ => {} } } fn control_out(&mut self, xfer: ControlOut) { let req = xfer.request(); if req.request_type == RequestType::Class && req.recipient == Recipient::Interface { if let Ok(request) = Request::try_from(req.request) { if request == Request::SetReport { self.set_report(xfer); } } } } }