// UsbFlashDrive.cpp 2013/1/25
#include "mbed.h"
#include "rtos.h"
#include "BaseUsbHost.h"
//#define DEBUG
#include "BaseUsbHostDebug.h"
#define TEST
#include "BaseUsbHostTest.h"
#include "UsbFlashDrive.h"

//#define WRITE_PROTECT


// Decode a 32-bit big-endian (most significant byte first) value from d[0..3].
// Each byte is widened to uint32_t before shifting, so a high byte >= 0x80
// never shifts into the sign bit of a promoted int.
uint32_t BE32(uint8_t* d)
{
    return (static_cast<uint32_t>(d[0]) << 24) |
           (static_cast<uint32_t>(d[1]) << 16) |
           (static_cast<uint32_t>(d[2]) << 8)  |
            static_cast<uint32_t>(d[3]);
}

// Store the low 16 bits of n into d[0..1], most significant byte first.
// Bits above bit 15 are discarded.
void BE16(uint32_t n, uint8_t* d)
{
    for (int idx = 0; idx < 2; ++idx) {
        const int shift = 8 * (1 - idx);
        d[idx] = static_cast<uint8_t>(n >> shift);
    }
}

// Store n into d[0..3] in big-endian order (most significant byte first).
void BE32(uint32_t n, uint8_t* d)
{
    for (int idx = 0; idx < 4; ++idx) {
        const int shift = 8 * (3 - idx);
        d[idx] = static_cast<uint8_t>(n >> shift);
    }
}

// Construct a FAT file system backed by a USB mass-storage device.
// name  : mount name passed through to FATFileSystem.
// ctlEp : control endpoint of the device; NULL means the drive sits directly
//         on the root hub and a fresh ControlEp is created here.
// Sequence: discover bulk endpoints, select configuration 1, query the
// number of LUNs, then wait for the unit to become ready (setup()).
UsbFlashDrive::UsbFlashDrive(const char* name, ControlEp* ctlEp): FATFileSystem(name)
{
    m_name = name;

    if (!ctlEp) { // root hub
        DBG_OHCI(LPC_USB->HcRhPortStatus1);
        TEST_ASSERT_FALSE(LPC_USB->HcRhPortStatus1 & 0x200);
        // NOTE(review): this ControlEp is never freed; confirm whether it must
        // outlive the constructor before adding cleanup.
        ctlEp = new ControlEp();
        TEST_ASSERT_TRUE(ctlEp);
    }

    // Bulk-Only Transport wrappers must be exactly 31 (CBW) and 13 (CSW) bytes.
    CTASSERT(sizeof(CBW) == 31);
    CTASSERT(sizeof(CSW) == 13);
    TEST_ASSERT(sizeof(CBW) == 31);
    TEST_ASSERT(sizeof(CSW) == 13);

    // Start from a clean state; ParseConfiguration() and setup() fill these in.
    m_pEpBulkIn = NULL;
    m_pEpBulkOut = NULL;
    m_interface = 0;
    m_lun = 0;
    m_BlockSize = 0;
    m_numBlocks = 0;

    ParseConfiguration(ctlEp);
    const int configRc = ctlEp->SetConfiguration(1);
    TEST_ASSERT_EQUAL(configRc, USB_OK);
    GetMaxLUN(ctlEp);
    setup(ctlEp);
}

// Return true when the device behind ctlEp looks like a mass-storage device
// this driver can handle. Either the device descriptor reports class 8
// directly, or the class is deferred to the interface (0x00) and the first
// interface descriptor is class 8 / subclass 6 / protocol 0x50.
// Returns false on NULL ctlEp, any transfer error, or a malformed
// configuration descriptor.
bool UsbFlashDrive::check(ControlEp* ctlEp)
{
    if (ctlEp == NULL) {
        return false;
    }
    CTASSERT(sizeof(StandardDeviceDescriptor) == 18);
    CTASSERT(sizeof(StandardConfigurationDescriptor) == 9);
    CTASSERT(sizeof(StandardInterfaceDescriptor) == 9);
    TEST_ASSERT(sizeof(StandardDeviceDescriptor) == 18);
    TEST_ASSERT(sizeof(StandardConfigurationDescriptor) == 9);
    TEST_ASSERT(sizeof(StandardInterfaceDescriptor) == 9);

    StandardDeviceDescriptor desc;
    int rc = ctlEp->GetDescriptor(USB_DESCRIPTOR_TYPE_DEVICE, 0, reinterpret_cast<uint8_t*>(&desc), sizeof(StandardDeviceDescriptor));
    if (rc != USB_OK) {
        return false;
    }
    if (desc.bDeviceClass == 8) {
        return true;
    } else if (desc.bDeviceClass != 0x00) {
        return false;
    }

    // Read only the header first to learn wTotalLength, then fetch everything.
    uint8_t temp[4];
    rc = ctlEp->GetDescriptor(USB_DESCRIPTOR_TYPE_CONFIGURATION, 0, temp, sizeof(temp));
    if (rc != USB_OK) {
        return false;
    }
    const StandardConfigurationDescriptor* cfg = reinterpret_cast<const StandardConfigurationDescriptor*>(temp);
    const int total = cfg->wTotalLength;
    if (total < static_cast<int>(sizeof(StandardConfigurationDescriptor))) {
        return false; // too short to hold any interface descriptor
    }
    uint8_t* buf = new uint8_t[total];

    rc = ctlEp->GetDescriptor(USB_DESCRIPTOR_TYPE_CONFIGURATION, 0, buf, total);
    if (rc != USB_OK) {
        delete[] buf; // previously leaked on this path
        return false;
    }
    DBG_HEX(buf, total);

    bool ret = false;
    int pos = 0;
    while (pos + 1 < total) {
        const int len = buf[pos];
        // A zero/short bLength would loop forever; an overrun would read past buf.
        if (len < 2 || pos + len > total) {
            break;
        }
        const StandardInterfaceDescriptor* ifDesc = reinterpret_cast<const StandardInterfaceDescriptor*>(buf + pos);
        if (ifDesc->bDescriptorType == 4) { // interface ?
            if (len >= static_cast<int>(sizeof(StandardInterfaceDescriptor)) &&
                ifDesc->bInterfaceClass == 8 && ifDesc->bInterfaceSubClass == 6 && ifDesc->bInterfaceProtocol == 0x50) {
                ret = true;
            }
            break; // only the first interface is considered
        }
        pos += len;
    }
    delete[] buf;
    return ret;
}

// Report whether the drive is usable: 0 when the device reported 512-byte
// blocks (the only size this driver handles), 1 otherwise.
int UsbFlashDrive::disk_initialize()
{
    const bool supported = (m_BlockSize == 512);
    return supported ? 0 : 1;
}

// Write one block at `sector` from `buffer`. Returns 0 on success, 1 when the
// bulk transfer reported a negative result. Increments the write counter.
int UsbFlashDrive::disk_write(const uint8_t* buffer, uint64_t sector)
{
    m_report_disk_write++;
    const int sent = MS_BulkSend(sector, 1, buffer);
    return (sent < 0) ? 1 : 0;
}

// Read one block at `sector` into `buffer`. Returns 0 on success, 1 when the
// bulk transfer reported a negative result. Increments the read counter.
int UsbFlashDrive::disk_read(uint8_t* buffer, uint64_t sector)
{
    m_report_disk_read++;
    const int received = MS_BulkRecv(sector, 1, buffer);
    return (received < 0) ? 1 : 0;
}    

// Always reports status 0; only bumps the status-call counter.
int UsbFlashDrive::disk_status()
{
    ++m_report_disk_status;
    return 0;
}

// Nothing is cached locally, so sync is a no-op; only bumps the sync counter.
int UsbFlashDrive::disk_sync()
{
    ++m_report_disk_sync;
    return 0;
}

// Number of blocks on the medium as recorded by ReadCapacity().
uint64_t UsbFlashDrive::disk_sectors()
{
    const uint64_t sectors = m_numBlocks;
    DBG("m_numBlocks=%d\n", m_numBlocks);
    return sectors;
}

// Wait for the logical unit to become ready, then read its capacity and
// inquiry data. TEST UNIT READY is polled every 50 ms until it succeeds with
// CSW status 0. Polling stops when `timeout` ms elapse or after 81 failed
// attempts. REQUEST SENSE is issued after each failure to clear the
// device's sense condition.
// ctlEp is not used by this implementation.
// Returns 0 on success, -1 on timeout / too many retries.
int UsbFlashDrive::setup(ControlEp* ctlEp, int timeout)
{
    bool ready = false;
    int retry = 0;
    Timer t;
    t.start();
    t.reset();
    while(t.read_ms() < timeout) {
        DBG("retry=%d t=%d\n", retry, t.read_ms());
        if (retry > 80) {
            return -1;
        }
        int rc = TestUnitReady();
        DBG("TestUnitReady(): %d\n", rc);
        if (rc == USB_OK) {
            DBG("m_CSW.bCSWStatus: %02X\n", m_CSW.bCSWStatus);
            if (m_CSW.bCSWStatus == 0x00) {
                ready = true;
                break;
            }
        }
        GetSenseInfo();
        retry++;
        wait_ms(50);
    }
    // Decide on the recorded outcome, not by re-reading the timer: a unit that
    // became ready just before the deadline must not be reported as a timeout.
    if (!ready) {
        return -1;
    }
    ReadCapacity();
    Inquire();
    return 0;
}

// Walk configuration descriptor 0 and create BulkEp objects for the bulk IN
// and bulk OUT endpoints. The first endpoint found in each direction is kept
// and later duplicates are ignored, so no BulkEp is allocated and then lost.
// Returns USB_OK when both directions were found, the GetDescriptor error
// code on transfer failure, otherwise USB_ERROR.
int UsbFlashDrive::ParseConfiguration(ControlEp* ctlEp)
{
    TEST_ASSERT(ctlEp);
    TEST_ASSERT(sizeof(StandardEndpointDescriptor) == 7);
    // Read only the header first to learn wTotalLength, then fetch everything.
    uint8_t temp[4];
    int rc = ctlEp->GetDescriptor(USB_DESCRIPTOR_TYPE_CONFIGURATION, 0, temp, sizeof(temp));
    if (rc != USB_OK) {
        return rc;
    }
    const StandardConfigurationDescriptor* cfg = reinterpret_cast<const StandardConfigurationDescriptor*>(temp);
    const int total = cfg->wTotalLength;
    if (total < static_cast<int>(sizeof(StandardConfigurationDescriptor))) {
        return USB_ERROR; // too short to contain any endpoint descriptor
    }
    uint8_t* buf = new uint8_t[total];
    rc = ctlEp->GetDescriptor(USB_DESCRIPTOR_TYPE_CONFIGURATION, 0, buf, total);
    if (rc != USB_OK) {
        delete[] buf; // previously leaked on this path
        return rc;
    }
    DBG_HEX(buf, total);
    int pos = 0;
    while (pos + 1 < total) {
        const int len = buf[pos];
        // A zero/short bLength would loop forever; an overrun would read past buf.
        if (len < 2 || pos + len > total) {
            break;
        }
        const StandardEndpointDescriptor* desc = reinterpret_cast<const StandardEndpointDescriptor*>(buf + pos);
        if (desc->bDescriptorType == USB_DESCRIPTOR_TYPE_ENDPOINT &&
            len >= static_cast<int>(sizeof(StandardEndpointDescriptor)) &&
            (desc->bmAttributes & 0x03) == 2) { // transfer type bits 1..0: bulk
            if (desc->bEndpointAddress & 0x80) {
                if (m_pEpBulkIn == NULL) {
                    m_pEpBulkIn = new BulkEp(ctlEp->GetAddr(), desc->bEndpointAddress, desc->wMaxPacketSize);
                }
            } else {
                if (m_pEpBulkOut == NULL) {
                    m_pEpBulkOut = new BulkEp(ctlEp->GetAddr(), desc->bEndpointAddress, desc->wMaxPacketSize);
                }
            }
        }
        pos += len;
    }
    delete[] buf;
    if (m_pEpBulkIn && m_pEpBulkOut) {
        return USB_OK;
    }
    return USB_ERROR;    
}

// Send the Bulk-Only Mass Storage Reset class request (bRequest 0xFF,
// bmRequestType 0x21) to interface m_interface. There is no data stage.
// Returns the control-transfer result code.
int UsbFlashDrive::BulkOnlyMassStorageReset(ControlEp* ctlEp)
{
    TEST_ASSERT(ctlEp);
    const int result = ctlEp->controlReceive(0x21, 0xff, 0x0000, m_interface, NULL, 0);
    TEST_ASSERT(result == USB_OK);
    return result;
}

// Issue the Get Max LUN class request (bRequest 0xFE) and store the highest
// LUN index in m_MaxLUN. Devices with a single LUN may STALL this request.
// In that case, or on any other failure, m_MaxLUN falls back to 0 instead
// of an uninitialized byte. Returns the control-transfer result code.
int UsbFlashDrive::GetMaxLUN(ControlEp* ctlEp)
{
    TEST_ASSERT(ctlEp);
    TEST_ASSERT(m_interface == 0);
    uint8_t temp[1] = {0};
    int rc = ctlEp->controlReceive(0xa1, 0xfe, 0x0000, m_interface, temp, sizeof(temp)); 
    TEST_ASSERT(rc == USB_OK);
    DBG_BYTES("GetMaxLUN", temp, sizeof(temp));
    m_MaxLUN = (rc == USB_OK) ? temp[0] : 0;
    TEST_ASSERT(m_MaxLUN <= 15);
    return rc;
}


// Issue SCSI TEST UNIT READY (no data stage). The caller inspects
// m_CSW.bCSWStatus for the device's verdict.
// Returns 0 when both the command and the status transport completed.
// Otherwise it returns the negative transport result (this file treats
// negative transfer results as failure, as in disk_read). A stale m_CSW left
// by an earlier command is therefore never mistaken for "ready".
int UsbFlashDrive::TestUnitReady()
{
    const uint8_t cdb[6] = {SCSI_CMD_TEST_UNIT_READY, 0x00, 0x00, 0x00, 0x00, 0x00};
    m_CBW.dCBWDataTraansferLength = 0;
    m_CBW.bmCBWFlags = 0x00;
    int rc = CommandTransport(cdb, sizeof(cdb));
    if (rc < 0) {
        return rc;
    }
    rc = StatusTransport();
    if (rc < 0) {
        return rc;
    }
    return 0;
}

int UsbFlashDrive::GetSenseInfo()
{
    const uint8_t cdb[6] = {SCSI_CMD_REQUEST_SENSE, 0x00, 0x00, 0x00, 18, 0x00};
    m_CBW.dCBWDataTraansferLength = 18;
    m_CBW.bmCBWFlags = 0x80; // data In
    CommandTransport(cdb, sizeof(cdb));

    uint8_t buf[18];
    _bulkRecv(buf, sizeof(buf));
    DBG_HEX(buf, sizeof(buf));

    StatusTransport();
    TEST_ASSERT(m_CSW.bCSWStatus == 0x00);
    return 0;
}

// Issue SCSI READ CAPACITY(10) and record the medium geometry. The 8-byte
// response holds two big-endian fields:
//   bytes 0..3: address of the LAST logical block
//   bytes 4..7: block length in bytes
// The block count is therefore last LBA + 1.
// Returns 0 on success. Returns -1, leaving m_numBlocks / m_BlockSize
// untouched, if the data stage failed or the CSW reported a failed command.
int UsbFlashDrive::ReadCapacity()
{
    const uint8_t cdb[10] = {SCSI_CMD_READ_CAPACITY, 0x00, 0x00, 0x00, 0x00, 
                                               0x00, 0x00, 0x00, 0x00, 0x00};
    m_CBW.dCBWDataTraansferLength = 8;
    m_CBW.bmCBWFlags = 0x80; // data In
    CommandTransport(cdb, sizeof(cdb));

    uint8_t buf[8] = {0};
    int rc = _bulkRecv(buf, sizeof(buf));
    TEST_ASSERT(rc >= 0);
    DBG_HEX(buf, sizeof(buf));

    StatusTransport();
    TEST_ASSERT(m_CSW.bCSWStatus == 0x00);
    if (rc < 0 || m_CSW.bCSWStatus != 0x00) {
        return -1; // response data is not trustworthy
    }

    const uint32_t lastLba = BE32(buf);
    // 0xFFFFFFFF means the capacity exceeds what READ CAPACITY(10) can express
    // (READ CAPACITY(16) would be required); clamp instead of wrapping to 0.
    m_numBlocks = (lastLba == 0xFFFFFFFFu) ? lastLba : lastLba + 1;
    m_BlockSize = BE32(buf+4);
    DBG("m_numBlocks=%d m_BlockSize=%d\n", m_numBlocks, m_BlockSize);
    TEST_ASSERT(m_BlockSize == 512);
    TEST_ASSERT(m_numBlocks > 0);
    return 0;
}

int UsbFlashDrive::Inquire()
{
    const uint8_t cdb[6] = {SCSI_CMD_INQUIRY, 0x00, 0x00, 0x00, 36, 0x00};
    m_CBW.dCBWDataTraansferLength = 36;
    m_CBW.bmCBWFlags = 0x80; // data In
    CommandTransport(cdb, sizeof(cdb));

    uint8_t buf[36];
    int rc = _bulkRecv(buf, sizeof(buf));
    TEST_ASSERT(rc >= 0);
    DBG_HEX(buf, sizeof(buf));

    StatusTransport();
    return 0;
}

// Read num_blocks blocks starting at block_number into user_buffer using
// SCSI READ(10). Only single 512-byte blocks are supported (asserted).
// Returns the result of the bulk-IN data transfer.
int UsbFlashDrive::MS_BulkRecv(uint32_t block_number, int num_blocks, uint8_t* user_buffer)
{
    TEST_ASSERT(m_BlockSize == 512);
    TEST_ASSERT(num_blocks == 1);
    TEST_ASSERT(user_buffer);

    uint8_t cdb[10] = {0};
    cdb[0] = SCSI_CMD_READ_10;
    BE32(block_number, cdb + 2); // logical block address
    BE16(num_blocks, cdb + 7);   // transfer length in blocks

    const uint32_t byteCount = m_BlockSize * num_blocks;
    TEST_ASSERT(byteCount <= 512);
    m_CBW.dCBWDataTraansferLength = byteCount;
    m_CBW.bmCBWFlags = 0x80; // device-to-host data stage
    CommandTransport(cdb, sizeof(cdb));

    const int received = _bulkRecv(user_buffer, byteCount);

    StatusTransport();
    TEST_ASSERT(m_CSW.bCSWStatus == 0x00);
    return received;
}

// Write num_blocks blocks from user_buffer starting at block_number using
// SCSI WRITE(10). Only single blocks of at most 512 bytes are supported
// (asserted). With WRITE_PROTECT defined, nothing is sent and 0 is returned.
// Returns the result of the bulk-OUT data transfer.
int UsbFlashDrive::MS_BulkSend(uint32_t block_number, int num_blocks, const uint8_t* user_buffer)
{
#ifdef WRITE_PROTECT
    return 0;
#else
    TEST_ASSERT(num_blocks == 1);
    TEST_ASSERT(user_buffer);

    uint8_t cdb[10] = {0};
    cdb[0] = SCSI_CMD_WRITE_10;
    BE32(block_number, cdb + 2); // logical block address
    BE16(num_blocks, cdb + 7);   // transfer length in blocks

    const uint32_t byteCount = m_BlockSize * num_blocks;
    TEST_ASSERT(byteCount <= 512);
    m_CBW.dCBWDataTraansferLength = byteCount;
    m_CBW.bmCBWFlags = 0x00; // host-to-device data stage
    CommandTransport(cdb, sizeof(cdb));

    const int sent = _bulkSend(user_buffer, byteCount);

    StatusTransport();
    TEST_ASSERT(m_CSW.bCSWStatus == 0x00);
    return sent;
#endif //WRITE_PROTECT    
}

// Build and send the Command Block Wrapper carrying the given SCSI CDB. The
// CBW gets the "USBC" signature, a fresh tag and the current LUN (m_lun).
// The caller must already have set m_CBW.dCBWDataTraansferLength and
// m_CBW.bmCBWFlags.
// Returns the bulk-OUT transfer result. Returns -1 if cdb is NULL or size
// does not fit the CBWCB field.
int UsbFlashDrive::CommandTransport(const uint8_t* cdb, int size)
{
    TEST_ASSERT(cdb);
    TEST_ASSERT(size >= 6);
    TEST_ASSERT(size <= 16);
    // TEST_ASSERT may be a no-op in non-TEST builds; never let memcpy overrun CBWCB.
    if (cdb == NULL || size <= 0 || size > static_cast<int>(sizeof(m_CBW.CBWCB))) {
        return -1;
    }

    m_CBW.dCBWSignature = 0x43425355; // "USBC" in little-endian byte order
    m_CBW.dCBWTag = m_tag++;          // echoed back by the device in the CSW
    m_CBW.bCBWLUN = m_lun;
    m_CBW.bCBWCBLength = size;
    // Clear the whole CDB field so bytes from a longer previous command never linger.
    memset(m_CBW.CBWCB, 0, sizeof(m_CBW.CBWCB));
    memcpy(m_CBW.CBWCB, cdb, size);

    return _bulkSend(reinterpret_cast<const uint8_t*>(&m_CBW), sizeof(CBW));
}

int UsbFlashDrive::StatusTransport()
{
    TEST_ASSERT(sizeof(CSW) == 13);
    int rc = _bulkRecv((uint8_t*)&m_CSW, sizeof(CSW));
    //DBG_HEX((uint8_t*)&m_CSW, sizeof(CSW));
    TEST_ASSERT(m_CSW.dCSWSignature == 0x53425355);
    TEST_ASSERT(m_CSW.dCSWTag == m_CBW.dCBWTag);
    TEST_ASSERT(m_CSW.dCSWDataResidue == 0);
    return rc;
}

// Receive up to `size` bytes from the bulk-IN endpoint; returns its result.
int UsbFlashDrive::_bulkRecv(uint8_t* buf, int size)
{
    TEST_ASSERT(m_pEpBulkIn);
    return m_pEpBulkIn->bulkReceive(buf, size);
}

// Send `size` bytes on the bulk-OUT endpoint; returns its result.
int UsbFlashDrive::_bulkSend(const uint8_t* buf, int size)
{
    TEST_ASSERT(m_pEpBulkOut);
    return m_pEpBulkOut->bulkSend(buf, size);
}
