From 951268b5f6b687b60f52608993ad14d28cfc2f08 Mon Sep 17 00:00:00 2001
From: Guillaume Endignoux <guillaumee@google.com>
Date: Mon, 9 Mar 2020 16:38:59 +0100
Subject: [PATCH] Add patch for bidirectional USB endpoints to fix panics in
 case of cancellation.

---
 patches/tock/05-usb-cancel.patch | 747 +++++++++++++++++++++++++++++++
 1 file changed, 747 insertions(+)
 create mode 100644 patches/tock/05-usb-cancel.patch

diff --git a/patches/tock/05-usb-cancel.patch b/patches/tock/05-usb-cancel.patch
new file mode 100644
index 0000000..f853971
--- /dev/null
+++ b/patches/tock/05-usb-cancel.patch
@@ -0,0 +1,747 @@
+diff --git a/capsules/src/usb/usbc_client.rs b/capsules/src/usb/usbc_client.rs
+index b0678c23..9fb43781 100644
+--- a/capsules/src/usb/usbc_client.rs
++++ b/capsules/src/usb/usbc_client.rs
+@@ -115,11 +115,11 @@ impl<'a, C: hil::usb::UsbController<'a>> hil::usb::Client<'a> for Client<'a, C>
+         self.client_ctrl.enable();
+ 
+         // Set up a bulk-in endpoint for debugging
+-        self.controller().endpoint_set_buffer(1, self.buffer(1));
++        self.controller().endpoint_set_in_buffer(1, self.buffer(1));
+         self.controller().endpoint_in_enable(TransferType::Bulk, 1);
+ 
+         // Set up a bulk-out endpoint for debugging
+-        self.controller().endpoint_set_buffer(2, self.buffer(2));
++        self.controller().endpoint_set_out_buffer(2, self.buffer(2));
+         self.controller().endpoint_out_enable(TransferType::Bulk, 2);
+     }
+ 
+diff --git a/capsules/src/usb/usbc_client_ctrl.rs b/capsules/src/usb/usbc_client_ctrl.rs
+index 2aaca0cc..5f9b253c 100644
+--- a/capsules/src/usb/usbc_client_ctrl.rs
++++ b/capsules/src/usb/usbc_client_ctrl.rs
+@@ -201,7 +201,7 @@ impl<'a, 'b, C: hil::usb::UsbController<'a>> ClientCtrl<'a, 'b, C> {
+     pub fn enable(&'a self) {
+         // Set up the default control endpoint
+         self.controller
+-            .endpoint_set_buffer(0, &self.ctrl_buffer.buf);
++            .endpoint_set_ctrl_buffer(&self.ctrl_buffer.buf);
+         self.controller
+             .enable_as_device(hil::usb::DeviceSpeed::Full); // must be Full for Bulk transfers
+         self.controller
+diff --git a/capsules/src/usb/usbc_ctap_hid.rs b/capsules/src/usb/usbc_ctap_hid.rs
+index fdf7263a..4b1916cf 100644
+--- a/capsules/src/usb/usbc_ctap_hid.rs
++++ b/capsules/src/usb/usbc_ctap_hid.rs
+@@ -88,8 +88,9 @@ static HID: HIDDescriptor<'static> = HIDDescriptor {
+ pub struct ClientCtapHID<'a, 'b, C: 'a> {
+     client_ctrl: ClientCtrl<'a, 'static, C>,
+ 
+-    // A 64-byte buffer for the endpoint
+-    buffer: Buffer64,
++    // 64-byte buffers for the endpoint
++    in_buffer: Buffer64,
++    out_buffer: Buffer64,
+ 
+     // Interaction with the client
+     client: OptionalCell<&'b dyn CtapUsbClient>,
+@@ -133,7 +134,8 @@ impl<'a, 'b, C: hil::usb::UsbController<'a>> ClientCtapHID<'a, 'b, C> {
+                 LANGUAGES,
+                 STRINGS,
+             ),
+-            buffer: Default::default(),
++            in_buffer: Default::default(),
++            out_buffer: Default::default(),
+             client: OptionalCell::empty(),
+             tx_packet: OptionalCell::empty(),
+             pending_in: Cell::new(false),
+@@ -187,7 +189,7 @@ impl<'a, 'b, C: hil::usb::UsbController<'a>> ClientCtapHID<'a, 'b, C> {
+     fn send_packet_to_client(&'a self) -> bool {
+         // Copy the packet into a buffer to send to the client.
+         let mut buf: [u8; 64] = [0; 64];
+-        for (i, x) in self.buffer.buf.iter().enumerate() {
++        for (i, x) in self.out_buffer.buf.iter().enumerate() {
+             buf[i] = x.get();
+         }
+ 
+@@ -220,11 +222,7 @@ impl<'a, 'b, C: hil::usb::UsbController<'a>> ClientCtapHID<'a, 'b, C> {
+ 
+     fn cancel_in_transaction(&'a self) -> bool {
+         self.tx_packet.take();
+-        let result = self.pending_in.take();
+-        if result {
+-            self.controller().endpoint_cancel_in(1);
+-        }
+-        result
++        self.pending_in.take()
+     }
+ 
+     fn cancel_out_transaction(&'a self) -> bool {
+@@ -243,7 +241,10 @@ impl<'a, 'b, C: hil::usb::UsbController<'a>> hil::usb::Client<'a> for ClientCtap
+         self.client_ctrl.enable();
+ 
+         // Set up the interrupt in-out endpoint
+-        self.controller().endpoint_set_buffer(1, &self.buffer.buf);
++        self.controller()
++            .endpoint_set_in_buffer(1, &self.in_buffer.buf);
++        self.controller()
++            .endpoint_set_out_buffer(1, &self.out_buffer.buf);
+         self.controller()
+             .endpoint_in_out_enable(TransferType::Interrupt, 1);
+     }
+@@ -293,7 +294,7 @@ impl<'a, 'b, C: hil::usb::UsbController<'a>> hil::usb::Client<'a> for ClientCtap
+                 }
+ 
+                 if let Some(packet) = self.tx_packet.take() {
+-                    let buf = &self.buffer.buf;
++                    let buf = &self.in_buffer.buf;
+                     for i in 0..64 {
+                         buf[i].set(packet[i]);
+                     }
+@@ -346,6 +347,12 @@ impl<'a, 'b, C: hil::usb::UsbController<'a>> hil::usb::Client<'a> for ClientCtap
+             panic!("Unexpected tx_packet while a packet was being transmitted.");
+         }
+         self.pending_in.set(false);
++
++        // Clear any pending packet on the receiving side.
++        // It's up to the client to handle the transmitted packet and decide if they want to
++        // receive another packet.
++        self.cancel_out_transaction();
++
+         // Notify the client
+         self.client.map(|client| client.packet_transmitted());
+     }
+diff --git a/chips/nrf52/src/usbd.rs b/chips/nrf52/src/usbd.rs
+index 8c1992cc..972871d0 100644
+--- a/chips/nrf52/src/usbd.rs
++++ b/chips/nrf52/src/usbd.rs
+@@ -623,7 +623,7 @@ pub enum UsbState {
+ pub enum EndpointState {
+     Disabled,
+     Ctrl(CtrlState),
+-    Bulk(TransferType, EndpointDirection, BulkState),
++    Bulk(TransferType, Option<BulkInState>, Option<BulkOutState>),
+ }
+ 
+ impl EndpointState {
+@@ -634,10 +634,10 @@ impl EndpointState {
+         }
+     }
+ 
+-    fn bulk_state(self) -> (TransferType, EndpointDirection, BulkState) {
++    fn bulk_state(self) -> (TransferType, Option<BulkInState>, Option<BulkOutState>) {
+         match self {
+-            EndpointState::Bulk(transfer_type, direction, state) => {
+-                (transfer_type, direction, state)
++            EndpointState::Bulk(transfer_type, in_state, out_state) => {
++                (transfer_type, in_state, out_state)
+             }
+             _ => panic!("Expected EndpointState::Bulk"),
+         }
+@@ -651,31 +651,18 @@ pub enum CtrlState {
+     ReadStatus,
+ }
+ 
+-#[derive(Copy, Clone, Debug)]
+-pub enum EndpointDirection {
+-    In,
+-    Out,
+-    InOut,
+-}
+-
+-impl EndpointDirection {
+-    fn has_in(&self) -> bool {
+-        match self {
+-            EndpointDirection::In | EndpointDirection::InOut => true,
+-            EndpointDirection::Out => false,
+-        }
+-    }
+-
+-    fn has_out(&self) -> bool {
+-        match self {
+-            EndpointDirection::Out | EndpointDirection::InOut => true,
+-            EndpointDirection::In => false,
+-        }
+-    }
++#[derive(Copy, Clone, PartialEq, Debug)]
++pub enum BulkInState {
++    // The endpoint is ready to perform transactions.
++    Init,
++    // There is a pending DMA transfer on this IN endpoint.
++    InDma,
++    // There is a pending IN packet transfer on this endpoint.
++    InData,
+ }
+ 
+ #[derive(Copy, Clone, PartialEq, Debug)]
+-pub enum BulkState {
++pub enum BulkOutState {
+     // The endpoint is ready to perform transactions.
+     Init,
+     // There is a pending OUT packet in this endpoint's buffer, to be read by
+@@ -685,14 +672,11 @@ pub enum BulkState {
+     OutData,
+     // There is a pending DMA transfer on this OUT endpoint.
+     OutDma,
+-    // There is a pending DMA transfer on this IN endpoint.
+-    InDma,
+-    // There is a pending IN packet transfer on this endpoint.
+-    InData,
+ }
+ 
+ pub struct Endpoint<'a> {
+-    slice: OptionalCell<&'a [VolatileCell<u8>]>,
++    slice_in: OptionalCell<&'a [VolatileCell<u8>]>,
++    slice_out: OptionalCell<&'a [VolatileCell<u8>]>,
+     state: Cell<EndpointState>,
+     // The USB controller can only process one DMA transfer at a time (over all endpoints). The
+     // request_transmit_* bits allow to queue transfers until the DMA becomes available again.
+@@ -705,7 +689,8 @@ pub struct Endpoint<'a> {
+ impl Endpoint<'_> {
+     const fn new() -> Self {
+         Endpoint {
+-            slice: OptionalCell::empty(),
++            slice_in: OptionalCell::empty(),
++            slice_out: OptionalCell::empty(),
+             state: Cell::new(EndpointState::Disabled),
+             request_transmit_in: Cell::new(false),
+             request_transmit_out: Cell::new(false),
+@@ -914,18 +899,12 @@ impl<'a> Usbd<'a> {
+                     chip_revision.get()
+                 );
+             }
+-            Some(ChipRevision::REV::Value::REVC) => {
++            Some(ChipRevision::REV::Value::REVC) | Some(ChipRevision::REV::Value::REVD) => {
+                 debug_info!(
+                     "Your chip is NRF52840 revision {}. The USB stack was tested on your chip :)",
+                     chip_revision.get()
+                 );
+             }
+-            Some(ChipRevision::REV::Value::REVD) => {
+-                internal_warn!(
+-                    "Your chip is NRF52840 revision {}. Although this USB implementation should be compatible, your chip hasn't been tested.",
+-                    chip_revision.get()
+-                );
+-            }
+             None => {
+                 internal_warn!(
+                     "Your chip is NRF52840 revision {} (unknown revision). Although this USB implementation should be compatible, your chip hasn't been tested.",
+@@ -1026,7 +1005,7 @@ impl<'a> Usbd<'a> {
+         });
+         self.descriptors[endpoint].state.set(match endpoint {
+             0 => EndpointState::Ctrl(CtrlState::Init),
+-            1..=7 => EndpointState::Bulk(transfer_type, EndpointDirection::In, BulkState::Init),
++            1..=7 => EndpointState::Bulk(transfer_type, Some(BulkInState::Init), None),
+             8 => unimplemented!("isochronous endpoint"),
+             _ => unreachable!("unexisting endpoint"),
+         });
+@@ -1064,7 +1043,7 @@ impl<'a> Usbd<'a> {
+         });
+         self.descriptors[endpoint].state.set(match endpoint {
+             0 => EndpointState::Ctrl(CtrlState::Init),
+-            1..=7 => EndpointState::Bulk(transfer_type, EndpointDirection::Out, BulkState::Init),
++            1..=7 => EndpointState::Bulk(transfer_type, None, Some(BulkOutState::Init)),
+             8 => unimplemented!("isochronous endpoint"),
+             _ => unreachable!("unexisting endpoint"),
+         });
+@@ -1114,7 +1093,11 @@ impl<'a> Usbd<'a> {
+         });
+         self.descriptors[endpoint].state.set(match endpoint {
+             0 => EndpointState::Ctrl(CtrlState::Init),
+-            1..=7 => EndpointState::Bulk(transfer_type, EndpointDirection::InOut, BulkState::Init),
++            1..=7 => EndpointState::Bulk(
++                transfer_type,
++                Some(BulkInState::Init),
++                Some(BulkOutState::Init),
++            ),
+             8 => unimplemented!("isochronous endpoint"),
+             _ => unreachable!("unexisting endpoint"),
+         });
+@@ -1304,13 +1287,13 @@ impl<'a> Usbd<'a> {
+             match desc.state.get() {
+                 EndpointState::Disabled => {}
+                 EndpointState::Ctrl(_) => desc.state.set(EndpointState::Ctrl(CtrlState::Init)),
+-                EndpointState::Bulk(transfer_type, direction, _) => {
++                EndpointState::Bulk(transfer_type, in_state, out_state) => {
+                     desc.state.set(EndpointState::Bulk(
+                         transfer_type,
+-                        direction,
+-                        BulkState::Init,
++                        in_state.map(|_| BulkInState::Init),
++                        out_state.map(|_| BulkOutState::Init),
+                     ));
+-                    if direction.has_out() {
++                    if out_state.is_some() {
+                         // Accept incoming OUT packets.
+                         regs.size_epout[ep].set(0);
+                     }
+@@ -1347,13 +1330,13 @@ impl<'a> Usbd<'a> {
+         match endpoint {
+             0 => {}
+             1..=7 => {
+-                let (transfer_type, direction, state) =
++                let (transfer_type, in_state, out_state) =
+                     self.descriptors[endpoint].state.get().bulk_state();
+-                assert_eq!(state, BulkState::InDma);
++                assert_eq!(in_state, Some(BulkInState::InDma));
+                 self.descriptors[endpoint].state.set(EndpointState::Bulk(
+                     transfer_type,
+-                    direction,
+-                    BulkState::InData,
++                    Some(BulkInState::InData),
++                    out_state,
+                 ));
+             }
+             8 => unimplemented!("isochronous endpoint"),
+@@ -1405,25 +1388,25 @@ impl<'a> Usbd<'a> {
+             1..=7 => {
+                 // Notify the client about the new packet.
+                 let packet_bytes = regs.size_epout[endpoint].get();
+-                let (transfer_type, direction, state) =
++                let (transfer_type, in_state, out_state) =
+                     self.descriptors[endpoint].state.get().bulk_state();
+-                assert_eq!(state, BulkState::OutDma);
++                assert_eq!(out_state, Some(BulkOutState::OutDma));
+ 
+-                self.debug_packet("out", packet_bytes as usize, endpoint);
++                self.debug_out_packet(packet_bytes as usize, endpoint);
+ 
+                 self.client.map(|client| {
+                     let result = client.packet_out(transfer_type, endpoint, packet_bytes);
+                     debug_packets!("packet_out => {:?}", result);
+-                    let newstate = match result {
++                    let new_out_state = match result {
+                         hil::usb::OutResult::Ok => {
+                             // Indicate that the endpoint is ready to receive data again.
+                             regs.size_epout[endpoint].set(0);
+-                            BulkState::Init
++                            BulkOutState::Init
+                         }
+ 
+                         hil::usb::OutResult::Delay => {
+                             // We can't send the packet now. Wait for a resume_out call from the client.
+-                            BulkState::OutDelay
++                            BulkOutState::OutDelay
+                         }
+ 
+                         hil::usb::OutResult::Error => {
+@@ -1432,13 +1415,13 @@ impl<'a> Usbd<'a> {
+                                     + EndpointStall::IO::Out
+                                     + EndpointStall::STALL::Stall,
+                             );
+-                            BulkState::Init
++                            BulkOutState::Init
+                         }
+                     };
+                     self.descriptors[endpoint].state.set(EndpointState::Bulk(
+                         transfer_type,
+-                        direction,
+-                        newstate,
++                        in_state,
++                        Some(new_out_state),
+                     ));
+                 });
+             }
+@@ -1497,29 +1480,27 @@ impl<'a> Usbd<'a> {
+         // Endpoint 8 (isochronous) doesn't receive any EPDATA event.
+         for endpoint in 1..NUM_ENDPOINTS {
+             if epdatastatus.is_set(status_epin(endpoint)) {
+-                let (transfer_type, direction, state) =
++                let (transfer_type, in_state, out_state) =
+                     self.descriptors[endpoint].state.get().bulk_state();
+-                match state {
+-                    BulkState::InData => {
++                assert!(in_state.is_some());
++                match in_state.unwrap() {
++                    BulkInState::InData => {
+                         // Totally expected state. Nothing to do.
+                     }
+-                    BulkState::Init => {
++                    BulkInState::Init => {
+                         internal_warn!(
+                             "Received a stale epdata IN in an unexpected state: {:?}",
+-                            state
++                            in_state
+                         );
+                     }
+-                    BulkState::OutDelay
+-                    | BulkState::OutData
+-                    | BulkState::OutDma
+-                    | BulkState::InDma => {
+-                        internal_err!("Unexpected state: {:?}", state);
++                    BulkInState::InDma => {
++                        internal_err!("Unexpected state: {:?}", in_state);
+                     }
+                 }
+                 self.descriptors[endpoint].state.set(EndpointState::Bulk(
+                     transfer_type,
+-                    direction,
+-                    BulkState::Init,
++                    Some(BulkInState::Init),
++                    out_state,
+                 ));
+                 self.client
+                     .map(|client| client.packet_transmitted(endpoint));
+@@ -1530,28 +1511,26 @@ impl<'a> Usbd<'a> {
+         // Endpoint 8 (isochronous) doesn't receive any EPDATA event.
+         for ep in 1..NUM_ENDPOINTS {
+             if epdatastatus.is_set(status_epout(ep)) {
+-                let (transfer_type, direction, state) =
++                let (transfer_type, in_state, out_state) =
+                     self.descriptors[ep].state.get().bulk_state();
+-                match state {
+-                    BulkState::Init => {
++                assert!(out_state.is_some());
++                match out_state.unwrap() {
++                    BulkOutState::Init => {
+                         // The endpoint is ready to receive data. Request a transmit_out.
+                         self.descriptors[ep].request_transmit_out.set(true);
+                     }
+-                    BulkState::OutDelay => {
++                    BulkOutState::OutDelay => {
+                         // The endpoint will be resumed later by the client application with transmit_out().
+                     }
+-                    BulkState::OutData
+-                    | BulkState::OutDma
+-                    | BulkState::InDma
+-                    | BulkState::InData => {
+-                        internal_err!("Unexpected state: {:?}", state);
++                    BulkOutState::OutData | BulkOutState::OutDma => {
++                        internal_err!("Unexpected state: {:?}", out_state);
+                     }
+                 }
+                 // Indicate that the endpoint now has data available.
+                 self.descriptors[ep].state.set(EndpointState::Bulk(
+                     transfer_type,
+-                    direction,
+-                    BulkState::OutData,
++                    in_state,
++                    Some(BulkOutState::OutData),
+                 ));
+             }
+         }
+@@ -1564,8 +1543,8 @@ impl<'a> Usbd<'a> {
+         let state = self.descriptors[endpoint].state.get().ctrl_state();
+         match state {
+             CtrlState::Init => {
+-                let ep_buf = &self.descriptors[endpoint].slice;
+-                let ep_buf = ep_buf.expect("No slice set for this descriptor");
++                let ep_buf = &self.descriptors[endpoint].slice_out;
++                let ep_buf = ep_buf.expect("No OUT slice set for this descriptor");
+                 if ep_buf.len() < 8 {
+                     panic!("EP0 DMA buffer length < 8");
+                 }
+@@ -1697,21 +1676,21 @@ impl<'a> Usbd<'a> {
+         let regs = &*self.registers;
+ 
+         self.client.map(|client| {
+-            let (transfer_type, direction, state) =
++            let (transfer_type, in_state, out_state) =
+                 self.descriptors[endpoint].state.get().bulk_state();
+-            assert_eq!(state, BulkState::Init);
++            assert_eq!(in_state, Some(BulkInState::Init));
+ 
+             let result = client.packet_in(transfer_type, endpoint);
+             debug_packets!("packet_in => {:?}", result);
+-            let newstate = match result {
++            let new_in_state = match result {
+                 hil::usb::InResult::Packet(size) => {
+                     self.start_dma_in(endpoint, size);
+-                    BulkState::InDma
++                    BulkInState::InDma
+                 }
+ 
+                 hil::usb::InResult::Delay => {
+                     // No packet to send now. Wait for a resume call from the client.
+-                    BulkState::Init
++                    BulkInState::Init
+                 }
+ 
+                 hil::usb::InResult::Error => {
+@@ -1720,14 +1699,14 @@ impl<'a> Usbd<'a> {
+                             + EndpointStall::IO::In
+                             + EndpointStall::STALL::Stall,
+                     );
+-                    BulkState::Init
++                    BulkInState::Init
+                 }
+             };
+ 
+             self.descriptors[endpoint].state.set(EndpointState::Bulk(
+                 transfer_type,
+-                direction,
+-                newstate,
++                Some(new_in_state),
++                out_state,
+             ));
+         });
+     }
+@@ -1735,15 +1714,16 @@ impl<'a> Usbd<'a> {
+     fn transmit_out(&self, endpoint: usize) {
+         debug_events!("transmit_out({})", endpoint);
+ 
+-        let (transfer_type, direction, state) = self.descriptors[endpoint].state.get().bulk_state();
++        let (transfer_type, in_state, out_state) =
++            self.descriptors[endpoint].state.get().bulk_state();
+         // Starting the DMA can only happen in the OutData state, i.e. after an EPDATA event.
+-        assert_eq!(state, BulkState::OutData);
++        assert_eq!(out_state, Some(BulkOutState::OutData));
+         self.start_dma_out(endpoint);
+ 
+         self.descriptors[endpoint].state.set(EndpointState::Bulk(
+             transfer_type,
+-            direction,
+-            BulkState::OutDma,
++            in_state,
++            Some(BulkOutState::OutDma),
+         ));
+     }
+ 
+@@ -1751,9 +1731,9 @@ impl<'a> Usbd<'a> {
+         let regs = &*self.registers;
+ 
+         let slice = self.descriptors[endpoint]
+-            .slice
+-            .expect("No slice set for this descriptor");
+-        self.debug_packet("in", size, endpoint);
++            .slice_in
++            .expect("No IN slice set for this descriptor");
++        self.debug_in_packet(size, endpoint);
+ 
+         // Start DMA transfer
+         self.set_pending_dma();
+@@ -1766,8 +1746,8 @@ impl<'a> Usbd<'a> {
+         let regs = &*self.registers;
+ 
+         let slice = self.descriptors[endpoint]
+-            .slice
+-            .expect("No slice set for this descriptor");
++            .slice_out
++            .expect("No OUT slice set for this descriptor");
+ 
+         // Start DMA transfer
+         self.set_pending_dma();
+@@ -1777,10 +1757,27 @@ impl<'a> Usbd<'a> {
+     }
+ 
+     // Debug-only function
+-    fn debug_packet(&self, _title: &str, size: usize, endpoint: usize) {
++    fn debug_in_packet(&self, size: usize, endpoint: usize) {
++        let slice = self.descriptors[endpoint]
++            .slice_in
++            .expect("No IN slice set for this descriptor");
++        if size > slice.len() {
++            panic!("Packet is too large: {}", size);
++        }
++
++        let mut packet_hex = [0; 128];
++        packet_to_hex(slice, &mut packet_hex);
++        debug_packets!(
++            "in={}",
++            core::str::from_utf8(&packet_hex[..(2 * size)]).unwrap()
++        );
++    }
++
++    // Debug-only function
++    fn debug_out_packet(&self, size: usize, endpoint: usize) {
+         let slice = self.descriptors[endpoint]
+-            .slice
+-            .expect("No slice set for this descriptor");
++            .slice_out
++            .expect("No OUT slice set for this descriptor");
+         if size > slice.len() {
+             panic!("Packet is too large: {}", size);
+         }
+@@ -1788,8 +1785,7 @@ impl<'a> Usbd<'a> {
+         let mut packet_hex = [0; 128];
+         packet_to_hex(slice, &mut packet_hex);
+         debug_packets!(
+-            "{}={}",
+-            _title,
++            "out={}",
+             core::str::from_utf8(&packet_hex[..(2 * size)]).unwrap()
+         );
+     }
+@@ -1807,17 +1803,41 @@ impl<'a> power::PowerClient for Usbd<'a> {
+ }
+ 
+ impl<'a> hil::usb::UsbController<'a> for Usbd<'a> {
+-    fn endpoint_set_buffer(&self, endpoint: usize, buf: &'a [VolatileCell<u8>]) {
++    fn endpoint_set_ctrl_buffer(&self, buf: &'a [VolatileCell<u8>]) {
++        if buf.len() < 8 {
++            panic!("Endpoint buffer must be at least 8 bytes");
++        }
++        if !buf.len().is_power_of_two() {
++            panic!("Buffer size must be a power of 2");
++        }
++        self.descriptors[0].slice_in.set(buf);
++        self.descriptors[0].slice_out.set(buf);
++    }
++
++    fn endpoint_set_in_buffer(&self, endpoint: usize, buf: &'a [VolatileCell<u8>]) {
++        if buf.len() < 8 {
++            panic!("Endpoint buffer must be at least 8 bytes");
++        }
++        if !buf.len().is_power_of_two() {
++            panic!("Buffer size must be a power of 2");
++        }
++        if endpoint == 0 || endpoint >= NUM_ENDPOINTS {
++            panic!("Endpoint number is invalid");
++        }
++        self.descriptors[endpoint].slice_in.set(buf);
++    }
++
++    fn endpoint_set_out_buffer(&self, endpoint: usize, buf: &'a [VolatileCell<u8>]) {
+         if buf.len() < 8 {
+             panic!("Endpoint buffer must be at least 8 bytes");
+         }
+         if !buf.len().is_power_of_two() {
+             panic!("Buffer size must be a power of 2");
+         }
+-        if endpoint >= NUM_ENDPOINTS {
+-            panic!("Endpoint number is too high");
++        if endpoint == 0 || endpoint >= NUM_ENDPOINTS {
++            panic!("Endpoint number is invalid");
+         }
+-        self.descriptors[endpoint].slice.set(buf);
++        self.descriptors[endpoint].slice_out.set(buf);
+     }
+ 
+     fn enable_as_device(&self, speed: hil::usb::DeviceSpeed) {
+@@ -1900,8 +1920,8 @@ impl<'a> hil::usb::UsbController<'a> for Usbd<'a> {
+     fn endpoint_resume_in(&self, endpoint: usize) {
+         debug_events!("endpoint_resume_in({})", endpoint);
+ 
+-        let (_, direction, _) = self.descriptors[endpoint].state.get().bulk_state();
+-        assert!(direction.has_in());
++        let (_, in_state, _) = self.descriptors[endpoint].state.get().bulk_state();
++        assert!(in_state.is_some());
+ 
+         if self.dma_pending.get() {
+             debug_events!("requesting resume_in[{}]", endpoint);
+@@ -1916,20 +1936,21 @@ impl<'a> hil::usb::UsbController<'a> for Usbd<'a> {
+     fn endpoint_resume_out(&self, endpoint: usize) {
+         debug_events!("endpoint_resume_out({})", endpoint);
+ 
+-        let (transfer_type, direction, state) = self.descriptors[endpoint].state.get().bulk_state();
+-        assert!(direction.has_out());
++        let (transfer_type, in_state, out_state) =
++            self.descriptors[endpoint].state.get().bulk_state();
++        assert!(out_state.is_some());
+ 
+-        match state {
+-            BulkState::OutDelay => {
++        match out_state.unwrap() {
++            BulkOutState::OutDelay => {
+                 // The endpoint has now finished processing the last ENDEPOUT. No EPDATA event
+                 // happened in the meantime, so the state is now back to Init.
+                 self.descriptors[endpoint].state.set(EndpointState::Bulk(
+                     transfer_type,
+-                    direction,
+-                    BulkState::Init,
++                    in_state,
++                    Some(BulkOutState::Init),
+                 ));
+             }
+-            BulkState::OutData => {
++            BulkOutState::OutData => {
+                 // Although the client reported a delay before, an EPDATA event has
+                 // happened in the meantime. This pending transaction will now
+                 // continue in transmit_out().
+@@ -1942,25 +1963,11 @@ impl<'a> hil::usb::UsbController<'a> for Usbd<'a> {
+                     self.transmit_out(endpoint);
+                 }
+             }
+-            BulkState::Init | BulkState::OutDma | BulkState::InDma | BulkState::InData => {
+-                internal_err!("Unexpected state: {:?}", state);
++            BulkOutState::Init | BulkOutState::OutDma => {
++                internal_err!("Unexpected state: {:?}", out_state);
+             }
+         }
+     }
+-
+-    fn endpoint_cancel_in(&self, endpoint: usize) {
+-        debug_events!("endpoint_cancel_in({})", endpoint);
+-
+-        let (transfer_type, direction, state) = self.descriptors[endpoint].state.get().bulk_state();
+-        assert!(direction.has_in());
+-        assert_eq!(state, BulkState::InData);
+-
+-        self.descriptors[endpoint].state.set(EndpointState::Bulk(
+-            transfer_type,
+-            direction,
+-            BulkState::Init,
+-        ));
+-    }
+ }
+ 
+ fn status_epin(ep: usize) -> Field<u32, EndpointStatus::Register> {
+diff --git a/chips/sam4l/src/usbc/mod.rs b/chips/sam4l/src/usbc/mod.rs
+index 28a0b9f9..ab5b636f 100644
+--- a/chips/sam4l/src/usbc/mod.rs
++++ b/chips/sam4l/src/usbc/mod.rs
+@@ -1438,11 +1438,28 @@ fn endpoint_enable_interrupts(endpoint: usize, mask: FieldValue<u32, EndpointCon
+ }
+ 
+ impl hil::usb::UsbController<'a> for Usbc<'a> {
+-    fn endpoint_set_buffer(&self, endpoint: usize, buf: &'a [VolatileCell<u8>]) {
++    fn endpoint_set_ctrl_buffer(&self, buf: &'a [VolatileCell<u8>]) {
+         if buf.len() != 8 {
+             client_err!("Bad endpoint buffer size");
+         }
+ 
++        self._endpoint_bank_set_buffer(EndpointIndex::new(0), BankIndex::Bank0, buf);
++    }
++
++    fn endpoint_set_in_buffer(&self, endpoint: usize, buf: &'a [VolatileCell<u8>]) {
++        if buf.len() != 8 {
++            client_err!("Bad endpoint buffer size");
++        }
++
++        self._endpoint_bank_set_buffer(EndpointIndex::new(endpoint), BankIndex::Bank0, buf);
++    }
++
++    fn endpoint_set_out_buffer(&self, endpoint: usize, buf: &'a [VolatileCell<u8>]) {
++        if buf.len() != 8 {
++            client_err!("Bad endpoint buffer size");
++        }
++
++        // XXX: when implementing in_out endpoints, this should probably set a different slice than endpoint_set_in_buffer.
+         self._endpoint_bank_set_buffer(EndpointIndex::new(endpoint), BankIndex::Bank0, buf);
+     }
+ 
+@@ -1547,10 +1564,6 @@ impl hil::usb::UsbController<'a> for Usbc<'a> {
+         requests.resume_out = true;
+         self.requests[endpoint].set(requests);
+     }
+-
+-    fn endpoint_cancel_in(&self, _endpoint: usize) {
+-        unimplemented!()
+-    }
+ }
+ 
+ /// Static state to manage the USBC
+diff --git a/kernel/src/hil/usb.rs b/kernel/src/hil/usb.rs
+index 64610fa5..a114b30d 100644
+--- a/kernel/src/hil/usb.rs
++++ b/kernel/src/hil/usb.rs
+@@ -5,7 +5,9 @@ use crate::common::cells::VolatileCell;
+ /// USB controller interface
+ pub trait UsbController<'a> {
+     // Should be called before `enable_as_device()`
+-    fn endpoint_set_buffer(&self, endpoint: usize, buf: &'a [VolatileCell<u8>]);
++    fn endpoint_set_ctrl_buffer(&self, buf: &'a [VolatileCell<u8>]);
++    fn endpoint_set_in_buffer(&self, endpoint: usize, buf: &'a [VolatileCell<u8>]);
++    fn endpoint_set_out_buffer(&self, endpoint: usize, buf: &'a [VolatileCell<u8>]);
+ 
+     // Must be called before `attach()`
+     fn enable_as_device(&self, speed: DeviceSpeed);
+@@ -27,8 +29,6 @@ pub trait UsbController<'a> {
+     fn endpoint_resume_in(&self, endpoint: usize);
+ 
+     fn endpoint_resume_out(&self, endpoint: usize);
+-
+-    fn endpoint_cancel_in(&self, endpoint: usize);
+ }
+ 
+ #[derive(Clone, Copy, Debug)]
-- 
GitLab