/* mbed NationZ I2C/SPI TPM 2.0 Library,
 * Copyright (c) 2015, Microsoft Corporation Inc.
 * by Stefan Thom (LordOfDorks) StefanTh@Microsoft.com, Stefan@ThomsR.Us
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to deal
 * in the Software without restriction, including without limitation the rights
 * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
 * copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in
 * all copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
 * THE SOFTWARE.
 */
 
 #include "NationZ_TPM20.h"
 
// Constructor for the I2C variant of the chip.
// Allocates an mbed I2C master on the given pins and clocks it at 400 kHz.
// The SPI members stay NULL, so Execute() dispatches to ExecuteI2C().
NTZTPM20::NTZTPM20(
    PinName sda,
    PinName scl
    )
{
#ifdef TPM_NTZ_DEBUG_OUTPUT
    printf("NTZI2C.Init: ");
#endif
    // Nobody owns the TPM yet
    m_ExclusiveAccess = false;

    // This instance talks I2C only, so the SPI handles are left empty
    m_SPICSTpmDev = NULL;
    m_SPITpmDev = NULL;

    // Bring up the bus in 400 kHz fast mode
    I2C* bus = new I2C(sda, scl);
    bus->frequency(400000);
    m_I2CTpmDev = bus;
#ifdef TPM_NTZ_DEBUG_OUTPUT
    printf("OK.\n\r");
#endif
}

// Constructor for the SPI variant of the chip.
// Allocates an mbed SPI master (8-bit frames, mode 0, 5 MHz) and a DigitalOut
// for the chip-select line. This driver treats CS as active low: 0 selects
// the TPM, 1 releases it. The I2C member stays NULL, so Execute() dispatches
// to ExecuteSPI().
NTZTPM20::NTZTPM20(
    PinName mosi,
    PinName miso,
    PinName clk,
    PinName cs
    )
{
#ifdef TPM_NTZ_DEBUG_OUTPUT
    printf("NTZSPI.Init: ");
#endif
    m_I2CTpmDev = NULL;
    m_SPITpmDev = new SPI(mosi, miso, clk);
    m_SPITpmDev->format(8, 0);
    m_SPITpmDev->frequency(5000000);

    // Create chip select already deasserted (high). The single-argument mbed
    // DigitalOut constructor initializes the pin low. Writing 1 afterwards
    // would therefore briefly assert CS, which the TPM could take as the
    // start of a transaction.
    m_SPICSTpmDev = new DigitalOut(cs, 1);
    m_ExclusiveAccess = false;
#ifdef TPM_NTZ_DEBUG_OUTPUT
    printf("OK.\n\r");
#endif
}

// Destructor: releases the bus objects the constructor allocated.
// Only one transport is populated per instance and the others are NULL.
// Deleting a NULL pointer is a no-op, so every member is freed unconditionally
// and then cleared.
NTZTPM20::~NTZTPM20()
{
#ifdef TPM_NTZ_DEBUG_OUTPUT
    printf("NTZ.Destroy: ");
#endif
    delete m_I2CTpmDev;
    m_I2CTpmDev = NULL;

    delete m_SPITpmDev;
    m_SPITpmDev = NULL;

    delete m_SPICSTpmDev;
    m_SPICSTpmDev = NULL;
#ifdef TPM_NTZ_DEBUG_OUTPUT
    printf("OK.\n\r");
#endif
}

// Runs one complete command/response exchange with the TPM over whichever
// transport the constructor set up.
//   pbCmd, cbCmd - marshalled command bytes to send
//   pbRsp, cbRsp - buffer for the response and its capacity. Callers in this
//                  file pass the command buffer again, so the two may alias.
//   timeout      - watchdog period. The Timeout is armed for 0.0001 * timeout
//                  seconds, i.e. the unit is 100 microseconds.
// Returns the number of response bytes written to pbRsp; 0 signals a failure
// or timeout in the transport routine.
// NOTE(review): m_ExclusiveAccess is a plain flag that is tested and then set,
// not an atomic lock. It only serializes callers that cannot preempt each other.
uint32_t
NTZTPM20::Execute(
    uint8_t* pbCmd,
    uint32_t cbCmd,
    uint8_t* pbRsp,
    uint32_t cbRsp,
    uint32_t timeout
    )
{
    uint32_t bytesReceived = 0;
    Timeout watchdog;

#ifdef TPM_NTZ_DEBUG_OUTPUT
    printf("NTZ.ExecuteWaitForAccess.");
#endif
    // Poll every 500 us until the current owner releases the TPM
    for(; m_ExclusiveAccess; wait_us(500))
    {
#ifdef TPM_NTZ_DEBUG_OUTPUT
        printf(".");
#endif
    }
    m_ExclusiveAccess = true;
#ifdef TPM_NTZ_DEBUG_OUTPUT
    printf("OK\n\r");
    printf("NTZ.SetupTimeout\n\r");
#endif

    // Arm the watchdog; the transport loops poll m_TimeoutTriggered to bail out
    m_TimeoutTriggered = false;
    watchdog.attach(this, &NTZTPM20::TimeoutTrigger, 0.0001 * timeout);

#ifdef TPM_NTZ_DEBUG_OUTPUT
    printf("NTZ.Execute: ");
    for(uint32_t i = 0; i < cbCmd; i++) printf("%02x ", pbCmd[i]);
    printf("\n\r");
#endif

    // Dispatch to the transport this instance was built for (none -> 0 bytes)
    bytesReceived = (m_I2CTpmDev != NULL) ? ExecuteI2C(pbCmd, cbCmd, pbRsp, cbRsp)
                  : (m_SPITpmDev != NULL) ? ExecuteSPI(pbCmd, cbCmd, pbRsp, cbRsp)
                  : 0;

#ifdef TPM_NTZ_DEBUG_OUTPUT
    printf("NTZ.Response: ");
    for(uint32_t i = 0; i < bytesReceived; i++) printf("%02x ", pbRsp[i]);
    printf("\n\r");
    printf("NTZ.CancelTimeout\n\r");
#endif
    // Disarm the watchdog before giving up ownership of the TPM
    watchdog.detach();
    m_ExclusiveAccess = false;
    return bytesReceived;
}

// Validates and decodes the 10-byte response header (tag, size, response
// code) starting at offset *cursor of pbRsp.
//   rspLen  - number of valid bytes in pbRsp
//   rspTag  - receives the tag field
//   rspSize - receives the size field
//   cursor  - in/out read offset, advanced past every field that is read
// Returns the response code carried in the header. Returns TPM_RC_FAILURE
// instead when any of these hold:
//   - the buffer is shorter than a header;
//   - the tag is neither TPM_ST_NO_SESSIONS nor TPM_ST_SESSIONS;
//   - the size field differs from rspLen.
// NOTE(review): the length check assumes *cursor starts at 0, which is what
// every caller in this file passes.
uint32_t
NTZTPM20::ParseResponseHeader(
    uint8_t* pbRsp,
    uint32_t rspLen,
    uint16_t* rspTag,
    uint32_t* rspSize,
    uint32_t* cursor
    )
{
    const uint32_t headerLen = sizeof(uint16_t) + sizeof(uint32_t) + sizeof(uint32_t);

    // Not even room for a header: nothing is decoded
    if(rspLen < headerLen)
    {
#ifdef TPM_NTZ_DEBUG_OUTPUT
        printf("NTZ.ResponseHdr.rspLen = 0x%08x\n\r", rspLen);
#endif
        return TPM_RC_FAILURE;
    }

    // Decode the fields in wire order, advancing the cursor after each one
    *rspTag = BYTEARRAY_TO_UINT16(pbRsp, *cursor);
    *cursor += sizeof(uint16_t);
    *rspSize = BYTEARRAY_TO_UINT32(pbRsp, *cursor);
    *cursor += sizeof(uint32_t);
    const uint32_t responseCode = BYTEARRAY_TO_UINT32(pbRsp, *cursor);
    *cursor += sizeof(uint32_t);

    // Reject unknown tags and size fields that disagree with what was received
    const bool knownTag = (*rspTag == TPM_ST_NO_SESSIONS) || (*rspTag == TPM_ST_SESSIONS);
    if(!knownTag || (*rspSize != rspLen))
    {
#ifdef TPM_NTZ_DEBUG_OUTPUT
        printf("NTZ.ResponseHdr.rspTag = 0x%04x.rspLen=0x%08x\n\r", *rspTag, rspLen);
#endif
        return TPM_RC_FAILURE;
    }

    return responseCode;
}

// Sends TPM2_Startup (command code 0x00000144) with the given startup type.
// Returns the TPM response code. Returns TPM_RC_FAILURE instead when the
// transport returns no data, the header is malformed, or the response is not
// exactly 10 bytes.
uint32_t
NTZTPM20::TPM2_Startup(
    uint16_t startupType
    )
{
    // Fixed 12-byte command: tag 0x8001, size 0x0000000c, command code
    // 0x00000144. The trailing 2-byte startup type is patched in below.
    uint8_t tpmCmd[] = {0x80, 0x01, 0x00, 0x00, 0x00, 0x0c, 0x00, 0x00, 0x01, 0x44, 0x00, 0x00};
    UINT16_TO_BYTEARRAY(startupType, tpmCmd, sizeof(tpmCmd) - 2);
    uint32_t rspResponseCode = TPM_RC_FAILURE;

#ifdef TPM_NTZ_DEBUG_OUTPUT
    printf("NTZ.TPM2_Startup(0x%04x)\n\r", startupType);
#endif

    // The response is read back into the command buffer
    const uint32_t rspLen = Execute(tpmCmd, sizeof(tpmCmd), tpmCmd, sizeof(tpmCmd), 10000);
    if(rspLen != 0)
    {
        uint16_t rspTag = 0;
        uint32_t rspSize = 0;
        uint32_t cursor = 0;
        rspResponseCode = ParseResponseHeader(tpmCmd, rspLen, &rspTag, &rspSize, &cursor);

        // A successful Startup response is a bare 10-byte header
        if((rspResponseCode == TPM_RC_SUCCESS) && (rspSize != 0x0000000a))
        {
            rspResponseCode = TPM_RC_FAILURE;
        }
    }

#ifdef TPM_NTZ_DEBUG_OUTPUT
    printf("NTZ.TPM2_Startup.ResponseCode = 0x%08x\n\r", rspResponseCode);
#endif
    return rspResponseCode;
}

// Sends TPM2_Shutdown (command code 0x00000145) with the given shutdown type.
// Returns the TPM response code. Returns TPM_RC_FAILURE instead when the
// transport returns no data, the header is malformed, or the response is not
// exactly 10 bytes.
uint32_t
NTZTPM20::TPM2_Shutdown(
    uint16_t shutdownType
    )
{
    // Fixed 12-byte command: tag 0x8001, size 0x0000000c, command code
    // 0x00000145. The trailing 2-byte shutdown type is patched in below.
    uint8_t tpmCmd[] = {0x80, 0x01, 0x00, 0x00, 0x00, 0x0c, 0x00, 0x00, 0x01, 0x45, 0x00, 0x00};
    UINT16_TO_BYTEARRAY(shutdownType, tpmCmd, sizeof(tpmCmd) - 2);
    uint32_t rspResponseCode = TPM_RC_FAILURE;

#ifdef TPM_NTZ_DEBUG_OUTPUT
    printf("NTZ.TPM2_Shutdown(0x%04x)\n\r", shutdownType);
#endif

    // The response is read back into the command buffer
    const uint32_t rspLen = Execute(tpmCmd, sizeof(tpmCmd), tpmCmd, sizeof(tpmCmd), 120000);
    if(rspLen != 0)
    {
        uint16_t rspTag = 0;
        uint32_t rspSize = 0;
        uint32_t cursor = 0;
        rspResponseCode = ParseResponseHeader(tpmCmd, rspLen, &rspTag, &rspSize, &cursor);

        // A successful Shutdown response is a bare 10-byte header
        if((rspResponseCode == TPM_RC_SUCCESS) && (rspSize != 0x0000000a))
        {
            rspResponseCode = TPM_RC_FAILURE;
        }
    }

#ifdef TPM_NTZ_DEBUG_OUTPUT
    printf("NTZ.TPM2_Shutdown.ResponseCode = 0x%08x\n\r", rspResponseCode);
#endif
    return rspResponseCode;
}

// Sends TPM2_SelfTest (command code 0x00000143) with the one-byte fullTest
// flag. Returns the TPM response code. Returns TPM_RC_FAILURE instead when the
// transport returns no data, the header is malformed, or the response is not
// exactly 10 bytes.
uint32_t
NTZTPM20::TPM2_SelfTest(
    uint8_t fullTest
    )
{
    // Fixed 11-byte command: tag 0x8001, size 0x0000000b, command code
    // 0x00000143. The final byte is the fullTest flag.
    uint8_t tpmCmd[] = {0x80, 0x01, 0x00, 0x00, 0x00, 0x0b, 0x00, 0x00, 0x01, 0x43, 0x00};
    tpmCmd[sizeof(tpmCmd) - 1] = fullTest;
    uint32_t rspResponseCode = TPM_RC_FAILURE;

#ifdef TPM_NTZ_DEBUG_OUTPUT
    printf("NTZ.TPM2_SelfTest(0x%02x)\n\r", fullTest);
#endif

    // The response is read back into the command buffer
    const uint32_t rspLen = Execute(tpmCmd, sizeof(tpmCmd), tpmCmd, sizeof(tpmCmd), 5000);
    if(rspLen != 0)
    {
        uint16_t rspTag = 0;
        uint32_t rspSize = 0;
        uint32_t cursor = 0;
        rspResponseCode = ParseResponseHeader(tpmCmd, rspLen, &rspTag, &rspSize, &cursor);

        // A successful SelfTest response is a bare 10-byte header
        if((rspResponseCode == TPM_RC_SUCCESS) && (rspSize != 0x0000000a))
        {
            rspResponseCode = TPM_RC_FAILURE;
        }
    }

#ifdef TPM_NTZ_DEBUG_OUTPUT
    printf("NTZ.TPM2_SelfTest.ResponseCode = 0x%08x\n\r", rspResponseCode);
#endif
    return rspResponseCode;
}

// Asks the TPM for bytesRequested random bytes (TPM2_GetRandom) and copies
// them into randomBytes.
//   bytesRequested - number of random bytes wanted
//   randomBytes    - caller buffer of at least bytesRequested bytes; may be
//                    NULL only when bytesRequested is 0
// Returns the TPM response code. Returns TPM_RC_FAILURE instead on
// allocation/transport failure, invalid arguments, or a malformed header.
// It also fails when the TPM does not return exactly bytesRequested bytes, so
// a success code always means randomBytes was completely filled.
// NOTE(review): TPMs may legitimately return fewer bytes than requested for
// large requests. Those responses fail here. TODO consider splitting large
// requests into several calls.
uint32_t
NTZTPM20::TPM2_GetRandom(
    uint16_t bytesRequested,
    uint8_t* randomBytes
    )
{
    uint32_t cursor = 0;
    uint32_t rspLen = 0;
    uint16_t rspTag = 0;
    uint32_t rspSize = 0;
    uint32_t rspResponseCode = 0;
    // Expected response: 10-byte header + 2-byte size + the random bytes.
    // The 12-byte command also fits, so one buffer serves both directions.
    uint32_t tpmMax = sizeof(uint16_t) + sizeof(uint32_t) + sizeof(uint32_t) + sizeof(uint16_t) + bytesRequested;
    uint8_t* tpmCmd = NULL;
    uint16_t bytesReturned = 0;

    // A non-empty request needs somewhere to put the result
    if((randomBytes == NULL) && (bytesRequested != 0))
    {
        rspResponseCode = TPM_RC_FAILURE;
        goto Cleanup;
    }

    tpmCmd = new uint8_t[tpmMax];
    if(tpmCmd == NULL)
    {
        rspResponseCode = TPM_RC_FAILURE;
        goto Cleanup;
    }

    // Build command: tag, size placeholder, command code, bytesRequested
    UINT16_TO_BYTEARRAY(TPM_ST_NO_SESSIONS, tpmCmd, cursor);
    cursor += sizeof(uint16_t) + sizeof(uint32_t);
    UINT32_TO_BYTEARRAY(TPM_CC_GetRandom, tpmCmd, cursor);
    cursor += sizeof(uint32_t);
    UINT16_TO_BYTEARRAY(bytesRequested, tpmCmd, cursor);
    cursor += sizeof(uint16_t);
    // Patch the now-known total command size into the header
    UINT32_TO_BYTEARRAY(cursor, tpmCmd, sizeof(uint16_t));

#ifdef TPM_NTZ_DEBUG_OUTPUT
    printf("NTZ.TPM2_GetRandom(%d)\n\r", bytesRequested);
#endif

    if((rspLen = Execute(tpmCmd, cursor, tpmCmd, tpmMax, 5000)) == 0)
    {
        rspResponseCode = TPM_RC_FAILURE;
        goto Cleanup;
    }
    cursor = 0;
    if((rspResponseCode = ParseResponseHeader(tpmCmd, rspLen, &rspTag, &rspSize, &cursor)) != TPM_RC_SUCCESS)
    {
        goto Cleanup;
    }

    // The response must be exactly the size allowed for in tpmMax
    if(rspSize != tpmMax)
    {
        rspResponseCode = TPM_RC_FAILURE;
        goto Cleanup;
    }

    // The embedded byte count must agree as well. Otherwise part of
    // randomBytes would be left unfilled while still reporting success.
    bytesReturned = BYTEARRAY_TO_UINT16(tpmCmd, cursor);
    cursor += sizeof(uint16_t);
    if(bytesReturned != bytesRequested)
    {
        rspResponseCode = TPM_RC_FAILURE;
        goto Cleanup;
    }
    if(bytesReturned != 0)
    {
        memcpy(randomBytes, &tpmCmd[cursor], bytesReturned);
    }

Cleanup:
#ifdef TPM_NTZ_DEBUG_OUTPUT
    printf("NTZ.TPM2_GetRandom.ResponseCode = 0x%08x\n\r", rspResponseCode);
#endif
    if(tpmCmd != NULL)
    {
        delete[] tpmCmd;
        tpmCmd = NULL;
    }
    return rspResponseCode;
}

// Mixes caller-supplied data into the TPM's random number generator
// (TPM2_StirRandom).
//   inDataLen - number of bytes in inData
//   inData    - additional input; may be NULL only when inDataLen is 0
// Returns the TPM response code. Returns TPM_RC_FAILURE instead when any of
// these hold:
//   - allocation or the transport fails;
//   - the arguments are invalid;
//   - the header is malformed;
//   - the response is not exactly 10 bytes.
// No upper bound on inDataLen is enforced here; the TPM is left to reject
// oversized input.
uint32_t
NTZTPM20::TPM2_StirRandom(
    uint16_t inDataLen,
    uint8_t* inData
    )
{
    uint32_t cursor = 0;
    uint32_t rspLen = 0;
    uint16_t rspTag = 0;
    uint32_t rspSize = 0;
    uint32_t rspResponseCode = 0;
    // Command: 10-byte header + 2-byte size + inData. The 10-byte response
    // also fits, so one buffer serves both directions.
    uint32_t tpmMax = sizeof(uint16_t) + sizeof(uint32_t) + sizeof(uint32_t) + sizeof(uint16_t) + inDataLen;
    uint8_t* tpmCmd = NULL;

    // Non-empty input must actually be supplied
    if((inData == NULL) && (inDataLen != 0))
    {
        rspResponseCode = TPM_RC_FAILURE;
        goto Cleanup;
    }

    tpmCmd = new uint8_t[tpmMax];
    if(tpmCmd == NULL)
    {
        rspResponseCode = TPM_RC_FAILURE;
        goto Cleanup;
    }

    // Build command: tag, size placeholder, command code, size-prefixed inData
    UINT16_TO_BYTEARRAY(TPM_ST_NO_SESSIONS, tpmCmd, cursor);
    cursor += sizeof(uint16_t) + sizeof(uint32_t);
    UINT32_TO_BYTEARRAY(TPM_CC_StirRandom, tpmCmd, cursor);
    cursor += sizeof(uint32_t);
    UINT16_TO_BYTEARRAY(inDataLen, tpmCmd, cursor);
    cursor += sizeof(uint16_t);
    if(inDataLen != 0)
    {
        memcpy(&tpmCmd[cursor], inData, inDataLen);
        cursor += inDataLen;
    }
    // Patch the now-known total command size into the header
    UINT32_TO_BYTEARRAY(cursor, tpmCmd, sizeof(uint16_t));

#ifdef TPM_NTZ_DEBUG_OUTPUT
    printf("NTZ.TPM2_StirRandom(%d)\n\r", inDataLen);
#endif

    if((rspLen = Execute(tpmCmd, cursor, tpmCmd, tpmMax, 5000)) == 0)
    {
        rspResponseCode = TPM_RC_FAILURE;
        goto Cleanup;
    }
    cursor = 0;
    if((rspResponseCode = ParseResponseHeader(tpmCmd, rspLen, &rspTag, &rspSize, &cursor)) != TPM_RC_SUCCESS)
    {
        goto Cleanup;
    }

    // A successful StirRandom response is a bare 10-byte header
    if(rspSize != 0x0000000a)
    {
        rspResponseCode = TPM_RC_FAILURE;
    }

Cleanup:
#ifdef TPM_NTZ_DEBUG_OUTPUT
    printf("NTZ.TPM2_StirRandom.ResponseCode = 0x%08x\n\r", rspResponseCode);
#endif
    if(tpmCmd != NULL)
    {
        delete[] tpmCmd;
        tpmCmd = NULL;
    }
    return rspResponseCode;
}

void
NTZTPM20::TimeoutTrigger(
    void
    )
{
    m_TimeoutTriggered = true;
}

// Sends a command to the TPM over I2C and reads back its response.
//   pbCmd, cbCmd - command bytes to write
//   pbRsp, cbRsp - response buffer and its capacity; a response larger than
//                  cbRsp is truncated to cbRsp bytes
// The routine retries the write until the TPM acknowledges it, since the TPM
// may be asleep. It then polls until the 10-byte response header can be read,
// backing off progressively. It aborts as soon as m_TimeoutTriggered is set.
// Returns the number of response bytes stored in pbRsp. Returns 0 instead on
// timeout, when cbRsp cannot hold a response header, or when the header's size
// field is smaller than the header itself.
uint32_t
NTZTPM20::ExecuteI2C(
    uint8_t* pbCmd,
    uint32_t cbCmd,
    uint8_t* pbRsp,
    uint32_t cbRsp
    )
{
    const uint32_t headerLen = sizeof(uint16_t) + sizeof(uint32_t) + sizeof(uint32_t);
    uint32_t result = 0;
    uint32_t waitPeriod = 0;
    uint8_t retry = 0;
    uint32_t index = 0;
    uint32_t rspLen = 0;

    // The header is always read in full, so the buffer must be able to hold it
    if(cbRsp < headerLen)
    {
#ifdef TPM_NTZ_DEBUG_OUTPUT
        printf("NTZI2C.ResponseBufferTooShort = %d\n\r", cbRsp);
#endif
        goto Cleanup;
    }

    // Send the command buffer. Since the TPM may be sleeping we have to retry until the command is accepted
#ifdef TPM_NTZ_DEBUG_OUTPUT
    printf("NTZI2C.DeviceWakeup.");
#endif
    while((m_I2CTpmDev->write(m_I2CDevice_Address, (const char *)pbCmd, cbCmd, false) != 0) && (!m_TimeoutTriggered))
    {
#ifdef TPM_NTZ_DEBUG_OUTPUT
        printf(".");
#endif
        wait_us(100);
    }

    // See if device wakeup timed out
    if(m_TimeoutTriggered)
    {
#ifdef TPM_NTZ_DEBUG_OUTPUT
        printf("Timeout\n\r");
#endif
        goto Cleanup;
    }

#ifdef TPM_NTZ_DEBUG_OUTPUT
    printf("OK\n\rNTZI2C.SentBytes = %d\n\rNTZI2C.WaitForCompletion.", cbCmd);
#endif

    // Wait for the TPM to have a response and then read the TPM response header consisting of TAG, SIZE and RETURNCODE
    waitPeriod = 100;
    while((m_I2CTpmDev->read(m_I2CDevice_Address, (char *)&pbRsp[index], headerLen, false) != 0) && (!m_TimeoutTriggered))
    {
#ifdef TPM_NTZ_DEBUG_OUTPUT
        printf(".");
#endif
        wait_us(waitPeriod);

        // Progressively wait longer to give the TPM some time to think
        if(retry++ > 10)
        {
            waitPeriod *= 10;
            retry = 0;
        }
    }

    if(m_TimeoutTriggered)
    {
#ifdef TPM_NTZ_DEBUG_OUTPUT
        printf("Timeout\n\r");
#endif
        goto Cleanup;
    }

    // Move the write pointer
    index += headerLen;

    // Get the response size from the buffer and see how much we can read
    rspLen = min(BYTEARRAY_TO_UINT32(pbRsp, sizeof(uint16_t)), cbRsp);

    // A size field smaller than the header is malformed. Without this check the
    // unsigned subtraction below would wrap around and overrun pbRsp.
    if(rspLen < index)
    {
#ifdef TPM_NTZ_DEBUG_OUTPUT
        printf("NTZI2C.InvalidResponseSize = %d\n\r", rspLen);
#endif
        goto Cleanup;
    }

    // Read the remaining data from the TPM
    if(rspLen > index)
    {
        while((m_I2CTpmDev->read(m_I2CDevice_Address, (char *)&pbRsp[index], rspLen - index, false) != 0) && (!m_TimeoutTriggered))
        {
#ifdef TPM_NTZ_DEBUG_OUTPUT
            printf(".");
#endif
            wait_us(100);
        }
        if(m_TimeoutTriggered)
        {
#ifdef TPM_NTZ_DEBUG_OUTPUT
            printf("Timeout\n\r");
#endif
            goto Cleanup;
        }
    }

#ifdef TPM_NTZ_DEBUG_OUTPUT
    printf("OK\n\rNTZI2C.ReadBytes = %d\r\n", rspLen);
#endif

    result = rspLen;

Cleanup:
    return result;
}

// Sends a command to the TPM over SPI and reads back its response.
//   pbCmd, cbCmd - command bytes to send
//   pbRsp, cbRsp - response buffer and its capacity; a response larger than
//                  cbRsp is truncated to cbRsp bytes
// Sequence as implemented here (chip select is driven 0 to select, 1 to release):
//   1. Wake the TPM by clocking 0xAA until it echoes 0xAA.
//   2. Send 0xA5 (commandReady), then the command bytes.
//   3. Release CS while the TPM works.
//   4. Poll with 0xAA until the first byte returned differs from 0xAA; that
//      byte starts the response.
//   5. Send 0x5A at the end in every case ("Sleep" in the debug output).
// The whole exchange is abandoned when m_TimeoutTriggered is set.
// Returns the number of response bytes stored in pbRsp: the smaller of the
// header's size field and cbRsp. Returns 0 on timeout, or when cbRsp cannot
// hold the 10-byte response header (the bus is not touched in that case).
uint32_t
NTZTPM20::ExecuteSPI(
    uint8_t* pbCmd,
    uint32_t cbCmd,
    uint8_t* pbRsp,
    uint32_t cbRsp
    )
{
    uint32_t result = 0;
    uint32_t waitPeriod = 0;
    uint8_t retry = 0;
    uint8_t statusByte = 0;
    uint32_t index = 0;

    // Once the TPM answers, the full header is stored into pbRsp unconditionally.
    // Reject buffers that cannot hold it before touching the bus.
    if(cbRsp < (sizeof(uint16_t) + sizeof(uint32_t) + sizeof(uint32_t)))
    {
#ifdef TPM_NTZ_DEBUG_OUTPUT
        printf("NTZSPI.ResponseBufferTooShort = %d\r\n", cbRsp);
#endif
        return 0;
    }

    // Lock the SPI bus for this operation
    *m_SPICSTpmDev = 0;

#ifdef TPM_NTZ_DEBUG_OUTPUT
    printf("NTZSPI.Wakeup.");
#endif

    // Wake the TPM up
    do
    {
        if((statusByte = m_SPITpmDev->write(0xAA)) != 0xAA)
        {
            wait_us(10);
        }
#ifdef TPM_NTZ_DEBUG_OUTPUT
        printf(".");
#endif
    } while((statusByte != 0xAA) && (!m_TimeoutTriggered));

    // If the operation timed out bail
    if((statusByte != 0xAA) && (m_TimeoutTriggered))
    {
#ifdef TPM_NTZ_DEBUG_OUTPUT
        printf("Timeout\r\n");
#endif
        goto Cleanup;
    }

#ifdef TPM_NTZ_DEBUG_OUTPUT
    printf("OK\r\nNTZSPI.CommandReady\r\n");
#endif
    // Signal commandReady to the TPM
    m_SPITpmDev->write(0xA5);

    // Send the command buffer
#ifdef TPM_NTZ_DEBUG_OUTPUT
    printf("NTZSPI.SendBytes = %d\r\n", cbCmd);
#endif
    for(uint32_t n = 0; n < cbCmd; n ++)
    {
        m_SPITpmDev->write(pbCmd[n]);
    }

    // Release the bus while the TPM is thinking
    *m_SPICSTpmDev = 1;

    // Wait for a response
#ifdef TPM_NTZ_DEBUG_OUTPUT
    printf("NTZSPI.WaitForCompletion.");
#endif
    waitPeriod = 100;
    do
    {
        wait_us(waitPeriod);

        // Progressively wait longer to give the TPM some time to think
        if(retry++ > 10)
        {
            waitPeriod *= 10;
            retry = 0;
        }
#ifdef TPM_NTZ_DEBUG_OUTPUT
        printf(".");
#endif

        // Lock the bus for this operation
        *m_SPICSTpmDev = 0;

        // Poke the TPM to see if it is ready
        if((pbRsp[index] = m_SPITpmDev->write(0xAA)) != 0xAA)
        {
            index++;

            // Read the TPM response header (cbRsp was checked to hold it)
            for(; index < (sizeof(uint16_t) + sizeof(uint32_t) + sizeof(uint32_t)); index++)
            {
                pbRsp[index] = m_SPITpmDev->write(0xFF);
            }

            uint32_t rspLen = BYTEARRAY_TO_UINT32(pbRsp, sizeof(uint16_t));
            result = min(rspLen, cbRsp);

#ifdef TPM_NTZ_DEBUG_OUTPUT
            if(result < rspLen)
            {
                printf("NTZSPI.ResponseBufferTooShortBy = %d", rspLen - result);
            }
#endif

            // Read the remainder of the response if there is one
            for(; index < result; index++)
            {
                pbRsp[index] = m_SPITpmDev->write(0xFF);
            }
        }

        // release the bus
        *m_SPICSTpmDev = 1;
    }
    while((pbRsp[0] == 0xAA) && (!m_TimeoutTriggered));

    if((pbRsp[0] == 0xAA) && (m_TimeoutTriggered))
    {
#ifdef TPM_NTZ_DEBUG_OUTPUT
        printf("Timeout\r\n");
#endif
        goto Cleanup;
    }

#ifdef TPM_NTZ_DEBUG_OUTPUT
    printf("OK\n\rNTZSPI.ReadBytes = %d\n\r", index);
#endif

Cleanup:
    // Make sure to release the bus
    *m_SPICSTpmDev = 1;

    // Send the TPM back to sleep
#ifdef TPM_NTZ_DEBUG_OUTPUT
    printf("NTZSPI.Sleep\r\n");
#endif
    *m_SPICSTpmDev = 0;
    m_SPITpmDev->write(0x5A);
    *m_SPICSTpmDev = 1;
    return result;
}

