blob: 54904d287aea014db19437b74ea50045d0ab9602 [file] [log] [blame]
/*
* Copyright 2025 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <fcntl.h>
#include <sys/mman.h>
#include <sys/stat.h>
#include <unistd.h>
#include <algorithm>
#include <cerrno>
#include <cstdint>
#include <cstring>
#include <limits>
#include <memory>
#include <string>
#include "absl/flags/flag.h"
#include "absl/flags/parse.h"
#include "absl/flags/usage.h"
#include "absl/log/check.h"
#include "absl/log/log.h"
#include "sw/host/tests/usbdev/usbdev_stream/usb_device.h"
#include "sw/host/tests/usbdev/usbdev_stream/usbdev_int.h"
// Path of the local file whose contents are streamed to the device.
// main() rejects an empty value, so this flag is effectively required.
ABSL_FLAG(std::string, file, "",
          "File to transfer to remote device.");
// Forwarded to the USB device/endpoint objects and gates the per-chunk
// progress messages printed while sending.
ABSL_FLAG(bool, verbose, false,
          "Enable verbose logging");
// Endpoint indices as handed to the USBDevInt streams. Each is one less than
// the matching device-side endpoint number, because the USBDevStream-derived
// classes add 1 to the index themselves.
constexpr unsigned int kDataEp{0};     // Stream carrying the file contents.
constexpr unsigned int kControlEp{1};  // Stream carrying the ControlPkt header.
// Header queued on the control endpoint before any file data is streamed.
// It is handed to the endpoint as its raw in-memory bytes, so its layout must
// not change without a matching change on the receiving side.
// NOTE(review): the field is sent in host byte order; this presumably relies
// on the device sharing the host's endianness — confirm against the firmware.
struct ControlPkt {
  uint32_t size;  // Number of bytes that will follow on the data endpoint.
};
static_assert(sizeof(ControlPkt) == sizeof(uint32_t),
              "ControlPkt is sent as raw bytes; it must stay 4 bytes with no "
              "padding");
int main(int argc, char** argv) {
absl::SetProgramUsageMessage("Matcha USB data loader");
auto args = absl::ParseCommandLine(argc, argv);
argc = args.size();
argv = &args[0];
if (absl::GetFlag(FLAGS_file) == "") {
LOG(ERROR) << "--file is required!";
return -1;
}
USBDevice dev(absl::GetFlag(FLAGS_verbose));
CHECK(dev.Init(0x18d1, 0x503a, 0, 0));
CHECK(dev.Open());
uint8_t* data;
int fd = open(absl::GetFlag(FLAGS_file).c_str(), 0);
CHECK(fd > 0);
struct stat sb;
CHECK(fstat(fd, &sb) == 0);
data = (uint8_t*)mmap(nullptr, sb.st_size, PROT_READ, MAP_PRIVATE, fd, 0);
CHECK(data != MAP_FAILED);
close(fd);
uint32_t transfer_bytes = sb.st_size;
std::unique_ptr<USBDevInt> data_ep, control_ep;
data_ep = std::make_unique<USBDevInt>(
&dev, /*bulk=*/true, /*idx=*/kDataEp, transfer_bytes,
/*retrieve=*/false, /*check=*/false, /*send=*/true,
absl::GetFlag(FLAGS_verbose));
CHECK(data_ep->Open(/*interface=*/0));
control_ep = std::make_unique<USBDevInt>(
&dev, /*bulk=*/true, /*idx=*/kControlEp, sizeof(ControlPkt),
/*retrieve=*/false, /*check=*/false, /*send=*/true,
absl::GetFlag(FLAGS_verbose));
CHECK(control_ep->Open(/*interface=*/1));
ControlPkt control;
control.size = transfer_bytes;
CHECK(control_ep->SpaceAvailable(nullptr) >= sizeof(ControlPkt));
CHECK(control_ep->AddData(
const_cast<const uint8_t*>(reinterpret_cast<uint8_t*>(&control)),
sizeof(control)));
do {
switch (dev.CurrentState()) {
case USBDevice::StateStreaming:
CHECK(control_ep->Service());
break;
default:
CHECK(false);
}
dev.Service();
} while (control_ep->BytesSent() < sizeof(ControlPkt));
uint32_t bytes_sent = 0;
do {
uint32_t avail = data_ep->SpaceAvailable(nullptr);
uint32_t bytes_to_queue = std::min(avail, transfer_bytes - bytes_sent);
CHECK(data_ep->AddData(data + bytes_sent, bytes_to_queue));
if (absl::GetFlag(FLAGS_verbose)) {
LOG(INFO) << "Queued " << bytes_to_queue << " bytes of data.";
}
bool done = false;
do {
switch (dev.CurrentState()) {
case USBDevice::StateStreaming:
CHECK(data_ep->Service());
if (data_ep->BytesSent() >= bytes_sent + bytes_to_queue) {
done = true;
}
break;
default:
CHECK(false);
}
dev.Service();
} while (!done);
bytes_sent = data_ep->BytesSent();
if (absl::GetFlag(FLAGS_verbose)) {
LOG(INFO) << "Sent " << bytes_to_queue << " (total: " << bytes_sent
<< ") bytes.";
}
} while (bytes_sent < transfer_bytes);
munmap(data, transfer_bytes);
if (dev.Fin()) {
return 0;
} else {
return 1;
}
}