ad-spi/host/main.cpp
2023-12-08 00:11:50 +01:00

331 lines
11 KiB
C++

#include <bits/fs_fwd.h>
#include <charconv>
#include <chrono>
#include <fstream>
#include <stdexcept>
#include <string>
#include <thread>
#include <iostream>
#include <format>
#include <csignal>
#include <filesystem>
#include <serial/serial.h>
#include "argparse/argparse.hpp"
#define NO_ARDUINO
#include "../arduino/common.hpp"
#include <gz-util/string/conversion.hpp>
static bool stopRequested = false;
// TODO: remove when c++23 is used
namespace std {
template <typename... Args>
inline void println(const std::format_string<Args...> fmt, Args&&... args) {
std::cout << std::vformat(fmt.get(), std::make_format_args(args...)) << std::endl;
}
}
/**
* @brief Wrapper for numbers, so that the default get from argparse isnt used for ints, since it does not handle 0x and 0b prefixes
*/
/* template<std::integral T> */
/* class NumberWrapper { */
/* public: */
/* NumberWrapper() : t(0) {}; */
/* NumberWrapper(const std::string& s) { */
/* if (s.size() >= 2) { */
/* if (s.at(1) == 'x') { */
/* t = gz::fromHexString<T>(s); */
/* return; */
/* } */
/* else if (s.at(1) == 'b') { */
/* t = gz::fromBinString<T>(s); */
/* return; */
/* } */
/* else if (s.at(1) == 'o') { */
/* t = gz::fromOctString<T>(s); */
/* return; */
/* } */
/* } */
/* t = gz::fromString<T>(s); */
/* } */
/* operator T() const { return t; } */
/* private: */
/* T t; */
/* }; */
/* // overload the argparse::get, which is used to convert the strings */
/* namespace argparse { */
/* template<typename T> */
/* inline NumberWrapper<T> get(const std::string& v) { return NumberWrapper<T>(v); }; */
/* } */
// formater for ControlBytes enum
template<>
struct std::formatter<ControlBytes> : std::formatter<std::string> {
template<class FormatContext>
auto format(ControlBytes c, FormatContext& fc) const {
return std::formatter<std::string>::format(ControlBytesString(c), fc);
}
};
class ArduinoException : public std::exception {
public:
ArduinoException(const std::string& message) : message(std::format("Unexpected Arduino behaviour: '{}'", message)) {}
virtual const char* what() const noexcept {
return message.c_str();
}
private:
const std::string message;
};
class ConnectionException : public std::exception {
public:
ConnectionException(const std::string& message) : message(std::format("Connection error: '{}'", message)) {}
virtual const char* what() const noexcept {
return message.c_str();
}
private:
const std::string message;
};
/**
* @brief Read data from a file and print it
*/
void signalHandler(int signal) {
std::println("Caught signal {}, exiting.", signal);
stopRequested = true;
}
/**
* @brief Wait for a response from the Arduino
* @details
* The response must always start with a ControlByte.
* If the ControlByte is invalid, <retry> more bytes are tried.
* If the ControlByte is PRINT, the message following the byte is printed and PRINT is returned.
* Else returnes the ControlByte
*/
ControlBytes waitArduino(serial::Serial& s, const std::string& fname) {
uint8_t ctrl;
while (!stopRequested) {
if (s.read(&ctrl, 1) == 0)
throw ConnectionException(std::format("{}: did not receive answer from Arduino", fname));
/* std::println("waitArduino: Received ctrl: {}", static_cast<ControlBytes>(ctrl)); */
switch (ctrl) {
case ControlBytes::PRINT: {
std::println("Arduino: {}", s.readline());
return ControlBytes::PRINT;
}
case ControlBytes::READ:
case ControlBytes::WRITE:
case ControlBytes::READY:
case ControlBytes::MEM_256KB:
case ControlBytes::MEM_2M:
case ControlBytes::SET_ADDRESS: {
return static_cast<ControlBytes>(ctrl);
break;
}
default: {
std::println("waitArduino: Received invalid ControlByte: '{}'", ctrl);
return ControlBytes::MAX_ENUM;
}
}
}
return ControlBytes::MAX_ENUM;
}
void sendControlByte(serial::Serial& s, ControlBytes ctrl, const std::string& fname) {
uint8_t ctrl8 = static_cast<uint8_t>(ctrl);
if (s.write(&ctrl8, 1) != 1) {
throw ConnectionException(std::format("{}: could not send {}", fname, ControlBytesString(ctrl)));
}
}
void receiveControlByte(serial::Serial& s, ControlBytes ctrl, const std::string& fname) {
ControlBytes receivedCtrl = waitArduino(s, fname);
if (receivedCtrl != ctrl) {
std::string err = std::format("{}: did not receive {}", fname, ControlBytesString(ctrl));
// SerialException saves the string, no ptr problems here
throw ConnectionException(err.data());
}
}
/**
* @brief Send a ready? command
* @details
* 1) Send `READY`
* 2) Receive `READY` or throw ArduinoException
*/
void getReady(serial::Serial& s) {
sendControlByte(s, ControlBytes::READY, "getReady");
receiveControlByte(s, ControlBytes::READY, "getReady");
}
/**
* @brief Send a write command and buffer
* @details
* 1) Send `WRITE` - `buffer size` - `buffer` - `WRITE`
* 2) Receive `WRITE` or throw ArduinoException
*/
void write(serial::Serial& s, std::vector<uint8_t> buffer) {
sendControlByte(s, ControlBytes::WRITE, "write(1)");
uint8_t cmd;
for (unsigned i = 0; i < sizeof(buffer_t); i++) {
cmd = (buffer.size() >> i * 8);
if (s.write(&cmd, 1) != 1)
throw ConnectionException("write: Could not send buffer size");
}
if (s.write(buffer) != buffer.size())
throw ConnectionException("write: Could not send buffer");
sendControlByte(s, ControlBytes::WRITE, "write(2)");
receiveControlByte(s, ControlBytes::WRITE, "write");
}
/**
* @brief Send a read command and receive the buffer
* @details
* 1) Send `READ` - `buffer size` - `READ`
* 2) Receive: `READ` - `buffer size` - `buffer` - `READ` or throw ArduinoException
*/
void read(serial::Serial& s, std::vector<uint8_t>& buffer, buffer_t bufferSize) {
sendControlByte(s, ControlBytes::READ, "read(1)");
uint8_t cmd;
std::println("read: bufferSize={}", bufferSize);
for (unsigned i = 0; i < sizeof(buffer_t); i++) {
cmd = (bufferSize >> i * 8);
/* std::println("read: sending={}", cmd); */
if (s.write(&cmd, 1) != 1)
throw ConnectionException("read: Could not send buffer size");
}
sendControlByte(s, ControlBytes::READ, "read(2)");
receiveControlByte(s, ControlBytes::READ, "read(1)");
buffer_t announcedBufferSize = 0;
for (unsigned i = 0; i < sizeof(buffer_t); i++) {
if (s.read(&cmd, 1) == 0)
throw ArduinoException(std::format("read: Could not receive buffer size"));
/* std::println("read: received={}", cmd); */
announcedBufferSize |= (cmd << i * 8);
}
if (announcedBufferSize != bufferSize)
throw ArduinoException(std::format("read: bufferSize={:#0x}, but announcedBufferSize={:#0x}", bufferSize, announcedBufferSize));
buffer.clear();
size_t receivedBufferSize = s.read(buffer, bufferSize);
if (receivedBufferSize != bufferSize)
throw ArduinoException(std::format("read: bufferSize={:#0x}, but receivedBufferSize={:#0x}", bufferSize, receivedBufferSize));
std::println("read: Received buffer size {}", receivedBufferSize);
receiveControlByte(s, ControlBytes::READ, "read(2)");
}
void validate(const std::vector<uint8_t>& correctData, const std::vector<uint8_t>& readData, address_t startAddress, bool fileStartsFromZero=false) {
size_t correctDataSize = fileStartsFromZero ? correctData.size() - startAddress : correctData.size();
unsigned errors = 0;
if (correctDataSize != readData.size()) {
std::println("validate: Buffers have different sizes. correctData.size={}, readData.size={}", correctDataSize, readData.size());
return;
}
for (unsigned i = 0; i < correctData.size(); i++) {
uint8_t correct = correctData.at(i);
uint8_t actual = readData.at(i);
if (correct != actual) {
std::println("validate: address=[0x{:04x} 0b{:015b}]: correct={{0x{:02x} 0b{:08b}}} - actual={{0x{:02x} 0b{:08b}}}", startAddress+i, startAddress+i, correct, correct, actual, actual);
errors++;
}
}
std::println("validate: {: 4} mismatches found", errors);
}
void readFile(const std::string& path, std::vector<uint8_t>& buffer) {
std::ifstream file(path, std::ios::binary);
// TODO: check if required
// Stop eating new lines in binary mode!!!
file.unsetf(std::ios::skipws);
std::streampos fileSize;
file.seekg(0, std::ios::end);
fileSize = file.tellg();
file.seekg(0, std::ios::beg);
buffer.clear();
buffer.reserve(fileSize);
// read the data:
buffer.insert(buffer.begin(), std::istream_iterator<uint8_t>(file), std::istream_iterator<uint8_t>());
}
struct Arguments : public argparse::Args {
std::string& filename = kwarg("f,file", "path to a binary file");
bool& verbose = flag("verbose", "Also print successful operations");
/* bool& write = flag("w,write", "Write data"); */
/* bool& read = flag("r,read", "Read data"); */
// timeout needs to be high enough for the entire write cycle
unsigned& timeout = kwarg("timeout", "Timeout in ms").set_default(5000);
std::string& device = kwarg("device", "Path to the serial device (Arduino)").set_default("/dev/ttyACM0");
/* virtual void welcome() { */
/* std::println("EEEPROM-programmer"); */
/* } */
};
namespace fs = std::filesystem;
int main(int argc, const char** argv) {
auto args = argparse::parse<Arguments>(argc, argv);
if (!fs::exists(fs::path(args.device))) {
std::println("Error: device '{}' not found", args.device);
return 1;
}
if (!fs::exists(fs::path(args.device))) {
std::println("Error: device '{}' not found", args.device);
return 1;
}
if (!fs::exists(fs::path(args.filename))) {
std::println("Error: file '{}' not found", args.filename);
return 1;
}
// read file
std::vector<uint8_t> binFromFile{};
if (!args.filename.empty()) {
readFile(args.filename, binFromFile);
}
const uint32_t baud_rate = 9600;
serial::Serial s(args.device, baud_rate);
auto timeout = serial::Timeout::simpleTimeout(args.timeout);
s.setTimeout(timeout);
s.flush();
std::this_thread::sleep_for(std::chrono::milliseconds(200));
try {
getReady(s);
write(s, binFromFile);
}
catch (const ArduinoException& e) {
std::println("ArduinoException: {}", e.what());
for (const auto& s : s.readlines()) {
std::println("From Arduino: '{}'", s);
}
}
s.close();
}