#include "WinUSBDevice.h"

#include <USBDescriptor.h>

#include <cstring>

using std::memcpy;

// String descriptor index at which Windows requests the Microsoft OS String Descriptor.
#define STRING_OFFSET_MSOS (0xEE)

// Vendor-defined request code (bRequest) the host uses to fetch the MS OS feature descriptors.
// It is advertised to the host through bMS_VendorCode in the 0xEE string descriptor, so any
// otherwise-unused value works; 33 was chosen because it is easy to spot while debugging.
#define GET_MS_DESCRIPTORS (33)


// wIndex values of a GET_MS_DESCRIPTORS request, selecting which MS OS feature descriptor is wanted.
#define GENRE_INDEX (1)
#define COMPAT_ID_INDEX (4)
#define EXTENDED_PROPERTIES_INDEX (5)

#define USB_DT_STRING (3) // USB 2.0 spec table 9.5; Linux header ch9.h

// Fallback for pre-C++11 compilers that lack the nullptr keyword.
// NOTE(review): in C++11 and later nullptr is a keyword, not a macro, so this #ifndef is always
// true and the keyword gets redefined as NULL. Defining a macro named like a keyword is not allowed
// in a translation unit that includes standard headers. Consider guarding with
// `__cplusplus < 201103L` once nothing in this file compares a char against nullptr.
#ifndef nullptr
#define nullptr NULL
#endif

// Copy a string, *not* including its null terminator to the destination, up to a maximum of maxBytes.
// Copy an 8-bit string into a fixed-size byte field, *not* including its NUL terminator.
//
// dest:     destination buffer; at most maxBytes bytes are written.
// maxBytes: size of the destination field. Nothing is written when it is <= 0.
// src:      NUL-terminated source string; must not be null.
//
// Bytes of dest past the copied characters are left untouched. Callers that need a
// NUL-padded field should zero it first (e.g. with memset).
void copyString8(uint8_t* dest, int maxBytes, const char* src)
{
    for (int i = 0; i < maxBytes; ++i)
    {
        // Compare against the NUL character, not nullptr: src[i] is a char, not a pointer.
        if (src[i] == '\0')
            return;
        dest[i] = static_cast<uint8_t>(src[i]);
    }
}

// Copy a string, *not* including its null terminator to the destination, up to a maximum of maxBytes, but in UTF-16.
// Copy an 8-bit string into a fixed-size byte field as UTF-16LE, *not* including a terminator.
//
// Each source char becomes one 2-byte code unit: the char as the low byte, followed by a
// zero high byte. This is correct for ASCII input only; bytes >= 0x80 are widened as if Latin-1.
//
// dest:     destination buffer; at most maxBytes bytes are written.
// maxBytes: size of the destination field in bytes. If it is odd, the final code unit is
//           truncated to its low byte.
// src:      NUL-terminated source string; must not be null.
//
// Bytes of dest past the copied code units are left untouched. Zero the field first if it
// must be NUL-padded/terminated.
void copyString16(uint8_t* dest, int maxBytes, const char* src)
{
    for (int byte = 0; byte < maxBytes; ++byte)
    {
        if (byte % 2 != 0)
        {
            // High byte of the current UTF-16LE code unit.
            dest[byte] = 0;
            continue;
        }

        // Compare against the NUL character, not nullptr: this is a char, not a pointer.
        const char ch = src[byte / 2];
        if (ch == '\0')
            return;
        dest[byte] = static_cast<uint8_t>(ch);
    }
}

// Print information about a USB setup transfer.
void debugUsbSetup(const CONTROL_TRANSFER& transfer)
{
    printf("\n");
    
    switch (transfer.setup.bmRequestType.Recipient)
    {
    case DEVICE_RECIPIENT:
        printf("Recipient: Device\n");
        break;
    case INTERFACE_RECIPIENT:
        printf("Recipient: Interface\n");
        break;
    case ENDPOINT_RECIPIENT:
        printf("Recipient: Endpoint\n");
        break;
    case OTHER_RECIPIENT:
        printf("Recipient: Other\n");
        break;
    default:
        printf("Recipient: Unknown\n");
        break;
    }
    
    printf("Length: %d\n", transfer.setup.wLength);
    
    switch (transfer.setup.bmRequestType.Type)
    {
    case STANDARD_TYPE:
        printf("Request type: Standard\nRequest: ");
        switch (transfer.setup.bRequest)
        {
        case GET_STATUS:
            printf("GET_STATUS\n");
            break;
        case CLEAR_FEATURE:
            printf("CLEAR_FEATURE\n");
            break;
        case SET_FEATURE:
            printf("SET_FEATURE\n");
            break;
        case SET_ADDRESS:
            printf("SET_ADDRESS\n");
            break;
        case GET_DESCRIPTOR:
            printf("GET_DESCRIPTOR\nDescriptor Type: ");
            switch (DESCRIPTOR_TYPE(transfer.setup.wValue))
            {
            case DEVICE_DESCRIPTOR:
                printf("DEVICE_DESCRIPTOR\n");
                break;
            case CONFIGURATION_DESCRIPTOR:
                printf("CONFIGURATION_DESCRIPTOR\n");
                break;
            case STRING_DESCRIPTOR:
                printf("STRING_DESCRIPTOR\nDescriptor Index: ");
                switch (DESCRIPTOR_INDEX(transfer.setup.wValue))
                {
                case STRING_OFFSET_LANGID:
                    printf("STRING_OFFSET_LANGID\n");
                    break;
                case STRING_OFFSET_IMANUFACTURER:
                    printf("STRING_OFFSET_IMANUFACTURER\n");
                    break;
                case STRING_OFFSET_IPRODUCT:
                    printf("STRING_OFFSET_IPRODUCT\n");
                    break;
                case STRING_OFFSET_ISERIAL:
                    printf("STRING_OFFSET_ISERIAL\n");
                    break;
                case STRING_OFFSET_ICONFIGURATION:
                    printf("STRING_OFFSET_ICONFIGURATION\n");
                    break;
                case STRING_OFFSET_IINTERFACE:
                    printf("STRING_OFFSET_IINTERFACE\n");
                    break;
                case STRING_OFFSET_MSOS:
                    // This is a Microsoft extension. I think it's a reasonable one before you get all anti-MS.
                    printf("STRING_OFFSET_MSOS\n"); 
                    break;
                default:
                    printf("Index 0x%02hhx\n", DESCRIPTOR_INDEX(transfer.setup.wValue));
                    break;
                }
                break;
            case INTERFACE_DESCRIPTOR:
                printf("INTERFACE_DESCRIPTOR\n");
                break;
            case ENDPOINT_DESCRIPTOR:
                printf("ENDPOINT_DESCRIPTOR\n");
                break;
            case QUALIFIER_DESCRIPTOR:
                printf("QUALIFIER_DESCRIPTOR\n");
                break;
            default:
                printf("Descriptor 0x%02hhx\n", DESCRIPTOR_TYPE(transfer.setup.wValue));
                break;
            }
            break;
        case SET_DESCRIPTOR:
            printf("SET_DESCRIPTOR\n");
            break;
        case GET_CONFIGURATION:
            printf("GET_CONFIGURATION\n");
            break;
        case SET_CONFIGURATION:
            printf("SET_CONFIGURATION\n");
            break;
        case GET_INTERFACE:
            printf("GET_INTERFACE\n");
            break;
        case SET_INTERFACE:
            printf("SET_INTERFACE\n");
            break;
        case 12: //SYNCH_FRAME:
            printf("SYNCH_FRAME\n");
            break;
        default:
            printf("Request 0x%02hhx\n", transfer.setup.bRequest);
        }
        break;
    case CLASS_TYPE:
        printf("Request type: Class Request 0x%02hhx\n", transfer.setup.bRequest);
        break;
    case VENDOR_TYPE:
        printf("Request type: Vendor\nRequest: ");
        switch (transfer.setup.bRequest)
        {
        // Note that GET_MS_DESCRIPTORS is an arbitrary value, which is returned in the STRING_OFFSET_MSOS string descriptor.
        case GET_MS_DESCRIPTORS:
            printf("GET_MS_DESCRIPTORS\nIndex: ");
            switch (transfer.setup.wIndex)
            {
            case COMPAT_ID_INDEX:
                printf("COMPAT_ID_INDEX\n");
                break;
            case EXTENDED_PROPERTIES_INDEX:
                printf("EXTENDED_PROPERTIES_INDEX\n");
                break;
            case GENRE_INDEX:
                printf("GENRE_INDEX\n");
                break;
            default:
                printf("0x%02hhx\n", transfer.setup.wIndex);
                break;
            }
            break;
        default:
            printf("Request 0x%02hhx\n", transfer.setup.bRequest);
            break;
        }
        break;
    case RESERVED_TYPE:
        printf("Request type: Reserved Request 0x%02hhx\n", transfer.setup.bRequest);
        break;
    default:
        printf("Request type: Unknown Request 0x%02hhx\n", transfer.setup.bRequest);
        break;
    }    
}


// Construct the device and pre-build the three Microsoft OS descriptors it serves:
//  - osStringDesc:         string descriptor 0xEE ("MSFT100"), advertising GET_MS_DESCRIPTORS
//                          as the vendor request code;
//  - compatIdData:         Extended Compat ID descriptor giving interface 0 the "WINUSB"
//                          compatible ID;
//  - extendedPropertyData: Extended Properties descriptor with property
//                          "DeviceInterfaceGUIDs" set to guid.
//
// guid: device interface GUID as an ASCII string; it is copied, widened to UTF-16, into the
//       property data. It may be null, in which case the property data is left zeroed.
WinUSBDevice::WinUSBDevice(uint16_t vendor_id, uint16_t product_id, uint16_t product_release, const char* guid) : USBDevice(vendor_id, product_id, product_release)
{
    // --- Microsoft OS String Descriptor (string index 0xEE) ---
    std::memset(&osStringDesc, 0, sizeof(osStringDesc));

    osStringDesc.bLength = sizeof(osStringDesc);
    osStringDesc.bDescriptorType = USB_DT_STRING;
    copyString16(osStringDesc.qwSignature, sizeof(osStringDesc.qwSignature), "MSFT100");
    // The host fetches the descriptors below with this vendor request code.
    osStringDesc.bMS_VendorCode = GET_MS_DESCRIPTORS;

    // --- Extended Compat ID descriptor (requested with wIndex == COMPAT_ID_INDEX) ---
    std::memset(&compatIdData, 0, sizeof(compatIdData));

    // Total descriptor length = size of the whole struct (header plus function section).
    // NOTE(review): assumes the struct is packed with no padding; confirm in WinUSBDevice.h.
    compatIdData.header.dwLength = sizeof(compatIdData);
    compatIdData.header.bcdVersion = 0x0100; // BCD version 1.00
    compatIdData.header.wIndex = COMPAT_ID_INDEX;
    compatIdData.header.bCount = 1; // One function section follows.
    // NOTE(review): the MS OS 1.0 spec zero-fills the header's reserved bytes and puts the
    // mandatory 0x01 reserved byte in the function section right after bFirstInterfaceNumber.
    // Confirm that header.reserved1 actually lands on that byte in the struct layout.
    compatIdData.header.reserved1 = 0x01;

    compatIdData.data[0].bFirstInterfaceNumber = 0;
    copyString8(compatIdData.data[0].compatibleID, sizeof(compatIdData.data[0].compatibleID), "WINUSB");

    // --- Extended Properties descriptor (requested with wIndex == EXTENDED_PROPERTIES_INDEX) ---
    std::memset(&extendedPropertyData, 0, sizeof(extendedPropertyData));

    // Total descriptor length = size of the whole struct (header plus property section).
    extendedPropertyData.header.dwLength = sizeof(extendedPropertyData);
    // BCD version 1.00, written in native byte order. This assumes a little-endian target,
    // which matches USB's little-endian wire format.
    extendedPropertyData.header.bcdVersion = 0x0100;
    extendedPropertyData.header.wIndex = EXTENDED_PROPERTIES_INDEX;
    extendedPropertyData.header.wCount = 1; // One custom property section follows.

    extendedPropertyData.data[0].dwSize = sizeof(extendedPropertyData.data[0]);
    // 1 means a single NUL-terminated Unicode string (REG_SZ). 7 means a list of strings
    // (REG_MULTI_SZ), used here for the plural "DeviceInterfaceGUIDs" variant.
    extendedPropertyData.data[0].dwPropertyDataType = 7;
    // Name and data lengths are the full field sizes. copyString16 writes no terminator, so
    // the trailing NULs come from the memset above (provided each field is longer than the
    // copied string).
    extendedPropertyData.data[0].wPropertyNameLength = sizeof(extendedPropertyData.data[0].bPropertyName);
    copyString16(extendedPropertyData.data[0].bPropertyName, sizeof(extendedPropertyData.data[0].bPropertyName), "DeviceInterfaceGUIDs");
    extendedPropertyData.data[0].dwPropertyDataLength = sizeof(extendedPropertyData.data[0].bPropertyData);
    // copyString16 dereferences its source unconditionally, so skip the copy for a null guid
    // instead of crashing.
    if (guid != nullptr)
        copyString16(extendedPropertyData.data[0].bPropertyData, sizeof(extendedPropertyData.data[0].bPropertyData), guid);
}

bool WinUSBDevice::USBCallback_request()
{
    // This can never be null.
    CONTROL_TRANSFER& transfer = *getTransferPtr();
        
    debugUsbSetup(transfer);
    
    // EXTENDED_PROPERTIES_INDEX is sent to the interface, not the device so we can't universally reject non-device requests.
    // if (transfer.setup.bmRequestType.Recipient != DEVICE_RECIPIENT)
    //     return false;

    // Intercept the request for string descriptor 0xEE
    if (transfer.setup.bmRequestType.Type == STANDARD_TYPE &&
        transfer.setup.bRequest == GET_DESCRIPTOR &&
        DESCRIPTOR_TYPE(transfer.setup.wValue) == STRING_DESCRIPTOR &&
        DESCRIPTOR_INDEX(transfer.setup.wValue) == STRING_OFFSET_MSOS)
    {
        transfer.ptr = reinterpret_cast<uint8_t*>(&osStringDesc);
        transfer.remaining = osStringDesc.bLength;
        transfer.direction = DEVICE_TO_HOST;
        return true;
    }

    // Our custom descriptor.
    if (transfer.setup.bmRequestType.Type == VENDOR_TYPE &&
        transfer.setup.bRequest == GET_MS_DESCRIPTORS)
    {
        switch (transfer.setup.wIndex)
        {
        case COMPAT_ID_INDEX: // Extended Compat ID
            transfer.ptr = reinterpret_cast<uint8_t*>(&compatIdData);
            transfer.remaining = sizeof(compatIdData);
            transfer.direction = DEVICE_TO_HOST;
            return true;
        case EXTENDED_PROPERTIES_INDEX: // Extended Properties.
            transfer.ptr = reinterpret_cast<uint8_t*>(&extendedPropertyData);
            transfer.remaining = sizeof(extendedPropertyData);
            transfer.direction = DEVICE_TO_HOST;
            return true;
        case GENRE_INDEX: // Genre
        default:
            return false;
        }
    }


    return false;
}