/*******************************************************************************
* Copyright (C) Maxim Integrated Products, Inc., All Rights Reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a
* copy of this software and associated documentation files (the "Software"),
* to deal in the Software without restriction, including without limitation
* the rights to use, copy, modify, merge, publish, distribute, sublicense,
* and/or sell copies of the Software, and to permit persons to whom the
* Software is furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included
* in all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
* OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
* IN NO EVENT SHALL MAXIM INTEGRATED BE LIABLE FOR ANY CLAIM, DAMAGES
* OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE,
* ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR
* OTHER DEALINGS IN THE SOFTWARE.
*
* Except as contained in this notice, the name of Maxim Integrated
* Products, Inc. shall not be used except as stated in the Maxim Integrated
* Products, Inc. Branding Policy.
*
* The mere transfer of this software does not imply any licenses
* of trade secrets, proprietary technology, copyrights, patents,
* trademarks, maskwork rights, or any other form of intellectual
* property whatsoever. Maxim Integrated Products, Inc. retains all
* ownership rights.
*******************************************************************************/

#include <stdint.h>
#include <stdio.h>
#include <cassert>
#include <cstdio>
#include <string>
#include <MaximInterfaceCore/HexString.hpp>
#include <MaximInterfaceCore/span.hpp>
#include <MaximInterfaceDevices/DS28C36_DS2476.hpp>
#include <rapidjson/stringbuffer.h>
#include <rapidjson/writer.h>
#include "DisplayGraphicWindow.hpp"
#include "DisplayIdWindow.hpp"
#include "ErrorWindow.hpp"
#include "Image.hpp"
#include "MakeFunction.hpp"
#include "NormalOperationWindow.hpp"
#include "Text.hpp"
#include "WindowManager.hpp"

#define TRY MaximInterfaceCore_TRY
#define TRY_VALUE MaximInterfaceCore_TRY_VALUE

namespace Core = MaximInterfaceCore;
using MaximInterfaceDevices::DS2476;

extern DS2476 coproc;
extern SensorNode sensorNode;
extern std::string webId;

extern "C" void ComputeSHA256(unsigned char * message, short length,
                              unsigned short skipconst, unsigned short reverse,
                              unsigned char * digest);

// Chunk size handed to every rapidjson::MemoryPoolAllocator created in this
// file.
static const size_t defaultChunkSize(256);

// Upper bound on the digits written after the decimal point whenever a JSON
// document is serialized (passed to Writer::SetMaxDecimalPlaces).
static const int jsonMaxDecimalPlaces(2);

// Separate multiple JSON commands received on the socket.
// Returns one span per complete top-level {...} object in receivedData, in
// order of appearance. Braces inside JSON string literals (including strings
// containing escaped quotes) are ignored, so they can neither split one
// command nor merge two. Bytes outside any object, and a trailing object that
// is not closed, are discarded.
static std::vector<Core::span<const char> >
separateCommands(Core::span<const char> receivedData) {
  std::vector<Core::span<const char> > commands;
  int depth = 0;
  bool inString = false;
  bool escaped = false;
  Core::span<const char>::index_type beginIdx = 0;
  for (Core::span<const char>::index_type i = 0; i < receivedData.size(); ++i) {
    const char c = receivedData[i];
    if (inString) {
      // Inside a string only the closing quote matters; a backslash makes the
      // following character literal (e.g. \" does not end the string).
      if (escaped) {
        escaped = false;
      } else if (c == '\\') {
        escaped = true;
      } else if (c == '"') {
        inString = false;
      }
    } else if (c == '{') {
      if (depth == 0) {
        beginIdx = i;
      }
      ++depth;
    } else if (c == '}') {
      if (depth > 0) {
        --depth;
        if (depth == 0) {
          // &receivedData[i] + 1 is one past the closing brace; this avoids
          // indexing the span at size() when the brace is the last byte.
          commands.push_back(
              Core::make_span(&receivedData[beginIdx], &receivedData[i] + 1));
        }
      }
    } else if (c == '"' && depth > 0) {
      // Strings are only tracked inside an object so that a stray quote in
      // inter-command noise cannot swallow the commands that follow it.
      inString = true;
    }
  }
  return commands;
}

Core::Result<void>
NormalOperationWindow::addCommandChallenge(rapidjson::Document & document) {
  // Draw a fresh random challenge from the coprocessor into commandChallenge;
  // verifySignedData() appends this value to the data it authenticates.
  TRY(coproc.readRng(commandChallenge));

  // Publish the challenge as a hex string (copied into the document's pool).
  rapidjson::Document::AllocatorType & alloc = document.GetAllocator();
  rapidjson::Value challengeHex(toHexString(commandChallenge).c_str(), alloc);
  document.AddMember("challenge", challengeHex, alloc);
  return Core::none;
}

Core::Result<void>
NormalOperationWindow::signData(rapidjson::Document & document,
                                bool validSignature,
                                const std::vector<uint8_t> & challenge) {
  rapidjson::Document::AllocatorType & alloc = document.GetAllocator();

  // Detach the current contents into `payload`; `document` is left holding an
  // empty object that will receive the "data"/"signature" envelope.
  rapidjson::Value payload(rapidjson::kObjectType);
  payload.Swap(document);

  // Serialize the payload exactly as it will later be sent, append the
  // caller-supplied challenge, and hash the result.
  rapidjson::StringBuffer serialized;
  rapidjson::Writer<rapidjson::StringBuffer> jsonWriter(serialized);
  jsonWriter.SetMaxDecimalPlaces(jsonMaxDecimalPlaces);
  payload.Accept(jsonWriter);
  const char * const text = serialized.GetString();
  std::vector<uint8_t> message(text, text + serialized.GetLength());
  message.insert(message.end(), challenge.begin(), challenge.end());
  uint8_t digest[32];
  ComputeSHA256(&message[0], message.size(), false, false, digest);

  // Have the coprocessor sign the digest with key A.
  TRY(coproc.writeBuffer(digest));
  Core::Ecc256::Signature::array sig;
  TRY_VALUE(sig, coproc.generateEcdsaSignature(DS2476::KeyNumA));
  if (!validSignature) {
    // Deliberately corrupt r so that the receiver sees a bad signature.
    ++sig.r[0];
  }

  // Envelope: {"data": <payload>, "signature": {"r": hex, "s": hex}}.
  rapidjson::Value signature(rapidjson::kObjectType);
  rapidjson::Value rHex(Core::toHexString(sig.r).c_str(), alloc);
  signature.AddMember("r", rHex, alloc);
  rapidjson::Value sHex(Core::toHexString(sig.s).c_str(), alloc);
  signature.AddMember("s", sHex, alloc);
  document.AddMember("data", payload, alloc);
  document.AddMember("signature", signature, alloc);
  return Core::none;
}

Core::Result<void> NormalOperationWindow::finalizeResponse(
    rapidjson::Document & document, bool validSignature,
    const std::vector<uint8_t> & responseChallenge) {
  // The new command challenge is added before signing so that it is covered
  // by the signature together with the rest of the payload.
  Core::Result<void> result = addCommandChallenge(document);
  if (!result) {
    return result;
  }
  return signData(document, validSignature, responseChallenge);
}

// Locates the raw text of one member value of the top-level JSON object in
// json. The member is selected by memberIndex, its zero-based position in
// document order (the order in which rapidjson stores parsed members).
//
// String literals (honouring backslash escapes) and the contents of nested
// objects and arrays are skipped, so key-like text, commas and colons inside
// them are never mistaken for top-level structure. Each top-level member
// contributes exactly one depth-1 colon, which marks the start of its value;
// the value ends at the next depth-1 comma or at the brace closing the
// top-level object.
//
// json must already have been accepted by a JSON parser as an object. On
// success, [valueBegin, valueEnd) covers the value text exactly as it appears
// in json (including any whitespace next to the separators) and true is
// returned.
static bool findTopLevelMemberValue(const std::string & json,
                                    std::size_t memberIndex,
                                    std::string::size_type & valueBegin,
                                    std::string::size_type & valueEnd) {
  int depth = 0;
  bool inString = false;
  bool escaped = false;
  bool inTargetValue = false;
  std::size_t memberCount = 0;
  for (std::string::size_type i = 0; i < json.size(); ++i) {
    const char c = json[i];
    if (inString) {
      if (escaped) {
        escaped = false;
      } else if (c == '\\') {
        escaped = true;
      } else if (c == '"') {
        inString = false;
      }
      continue;
    }
    switch (c) {
    case '"':
      inString = true;
      break;

    case '{':
    case '[':
      ++depth;
      break;

    case '}':
    case ']':
      --depth;
      if (inTargetValue && depth == 0) {
        valueEnd = i;
        return true;
      }
      break;

    case ':':
      if (depth == 1) {
        if (memberCount == memberIndex) {
          valueBegin = i + 1;
          inTargetValue = true;
        }
        ++memberCount;
      }
      break;

    case ',':
      if (inTargetValue && depth == 1) {
        valueEnd = i;
        return true;
      }
      break;

    default:
      break;
    }
  }
  return false;
}

// Parses verifyDataIn as {"data": {...}, "signature": {"r": hex, "s": hex}},
// verifies the signature with the coprocessor (key C) over the raw text of
// "data" followed by the current commandChallenge, and on success replaces
// signedData with the contents of "data".
// Returns DS2476::AuthenticationError for malformed input or a bad signature,
// or the coprocessor error if hashing/verification could not be performed.
// On any failure signedData is reset to an empty object.
Core::Result<void>
NormalOperationWindow::verifySignedData(rapidjson::Document & signedData,
                                        Core::span<const char> verifyDataIn) {
  using rapidjson::Value;
  using std::string;

  // Parse string and validate object schema.
  // Failure paths clear the document with SetObject(). rapidjson's
  // RemoveAllMembers() requires the value to already be an object, which does
  // not hold after a parse failure.
  string verifyData(verifyDataIn.begin(), verifyDataIn.end());
  signedData.Parse(verifyData.c_str());
  if (!(signedData.IsObject() && signedData.HasMember("data") &&
        signedData.HasMember("signature"))) {
    signedData.SetObject();
    return DS2476::AuthenticationError;
  }
  // FindMember() yields the first member named "data", the same member that
  // operator[] and HasMember() refer to.
  const Value::MemberIterator dataMember = signedData.FindMember("data");
  Value & data = dataMember->value;
  const Value & signature = signedData["signature"];
  if (!(data.IsObject() && signature.IsObject() && signature.HasMember("r") &&
        signature.HasMember("s"))) {
    signedData.SetObject();
    return DS2476::AuthenticationError;
  }
  const Value & signatureR = signature["r"];
  const Value & signatureS = signature["s"];
  if (!(signatureR.IsString() && signatureS.IsString())) {
    signedData.SetObject();
    return DS2476::AuthenticationError;
  }

  // Parse signature.
  Core::Optional<std::vector<uint8_t> > parsedBytes = Core::fromHexString(
      Core::make_span(signatureR.GetString(), signatureR.GetStringLength()));
  Core::Ecc256::Signature::array signatureBuffer;
  if (!(parsedBytes && parsedBytes->size() == signatureBuffer.r.size())) {
    signedData.SetObject();
    return DS2476::AuthenticationError;
  }
  std::copy(parsedBytes->begin(), parsedBytes->end(),
            signatureBuffer.r.begin());
  parsedBytes = Core::fromHexString(
      Core::make_span(signatureS.GetString(), signatureS.GetStringLength()));
  if (!(parsedBytes && parsedBytes->size() == signatureBuffer.s.size())) {
    signedData.SetObject();
    return DS2476::AuthenticationError;
  }
  std::copy(parsedBytes->begin(), parsedBytes->end(),
            signatureBuffer.s.begin());

  // Get data to hash.
  // The hashed text must be exactly the text that produced `data` above;
  // otherwise a valid signature over one payload could authenticate a
  // different payload that is then executed. Searching the raw string for a
  // "data" key is not enough: the same key nested inside another member, or an
  // escaped key such as "d\u0061ta", would redirect the search. The value is
  // therefore located by the parsed member's position among the top-level
  // members. This does not copy the data, so it stays memory-neutral.
  string::size_type rawDataBegin = 0;
  string::size_type rawDataEnd = 0;
  if (!findTopLevelMemberValue(
          verifyData,
          static_cast<std::size_t>(dataMember - signedData.MemberBegin()),
          rawDataBegin, rawDataEnd)) {
    signedData.SetObject();
    return DS2476::AuthenticationError;
  }
  verifyData.erase(rawDataEnd);
  verifyData.erase(0, rawDataBegin);
  // Add in command challenge to data that will be verified.
  verifyData.append(commandChallenge.begin(), commandChallenge.end());

  // Compute hash of the data.
  Core::Result<void> result = computeMultiblockHash(
      coproc,
      Core::make_span(reinterpret_cast<const uint8_t *>(verifyData.data()),
                      verifyData.size()));
  if (!result) {
    signedData.SetObject();
    return result;
  }
  // Verify signature.
  result = coproc.verifyEcdsaSignature(DS2476::KeyNumC, DS2476::THASH,
                                       signatureBuffer);
  if (!result) {
    signedData.SetObject();
    return result;
  }

  // Strip signing information from document.
  rapidjson::Value swapObject(rapidjson::kObjectType);
  swapObject.Swap(data);
  swapObject.Swap(signedData);
  return Core::none;
}

void NormalOperationWindow::sendJson(const rapidjson::Value & document) {
  // Serialize compactly, capping the decimal places written for numbers, and
  // transmit the resulting text over the socket (send()'s result is unused).
  rapidjson::StringBuffer output;
  rapidjson::Writer<rapidjson::StringBuffer> jsonWriter(output);
  jsonWriter.SetMaxDecimalPlaces(jsonMaxDecimalPlaces);
  document.Accept(jsonWriter);
  const char * const text = output.GetString();
  socket->send(text, output.GetLength());
}

void NormalOperationWindow::sendMessage(const char * message) {
  rapidjson::MemoryPoolAllocator<> allocator(defaultChunkSize);
  rapidjson::Document document(rapidjson::kObjectType, &allocator);
  document.AddMember("message", rapidjson::StringRef(message),
                     document.GetAllocator());
  sendJson(document);
}

// Label for the signature toggle button. It names the action the button will
// perform, i.e. the opposite of the current mode.
static std::string getValidSignatureButtonText(bool validSignature) {
  if (validSignature) {
    return "Use invalid sig.";
  }
  return "Use valid sig.";
}

void NormalOperationWindow::showWebId(Button *) {
  // Click handler for "Show web ID": opens the ID display as a popup.
  // Without a window manager there is nowhere to show it.
  if (!windowManager()) {
    return;
  }
  std::auto_ptr<Window> popup(new DisplayIdWindow(DisplayIdWindow::PopupMode));
  windowManager()->push(popup);
}

void NormalOperationWindow::toggleValidSignature(Button *) {
  // Flip the mode, then relabel the button to offer the opposite choice.
  const bool newMode = !validSignature;
  validSignature = newMode;
  validSignatureButton.setText(getValidSignatureButtonText(newMode));
}

// Creates the main operating window for an established server connection.
//
// socket: the connected TCP socket. Ownership is taken from the caller's
//         auto_ptr by the member initializer (auto_ptr copy-construction
//         transfers ownership, leaving the caller's pointer empty). It must be
//         non-null (asserted).
//
// sendChallenge starts true so that the first updated() call sends an initial
// command challenge to the server. Signatures start out valid. The last
// sensor state starts as Disconnected and both cached temperatures start at 0
// until real readings arrive.
NormalOperationWindow::NormalOperationWindow(std::auto_ptr<TCPSocket> & socket)
    : socket(socket) /* Move construct */, sendChallenge(true),
      validSignature(true), lastSensorNodeState(SensorNode::Disconnected),
      lastObjectTemp(0), lastAmbientTemp(0) {
  assert(this->socket.get());

  // Signature toggle button: its label reflects the current mode and a click
  // flips the mode.
  validSignatureButton.setParent(this);
  validSignatureButton.setText(getValidSignatureButtonText(validSignature));
  validSignatureButton.setClickedHandler(
      makeFunction(this, &NormalOperationWindow::toggleValidSignature));
  // Button that opens the web ID popup.
  showWebIdButton.setParent(this);
  showWebIdButton.setText("Show web ID");
  showWebIdButton.setClickedHandler(
      makeFunction(this, &NormalOperationWindow::showWebId));
  // The signature toggle starts with keyboard focus.
  validSignatureButton.setFocused();
}

NormalOperationWindow::Result NormalOperationWindow::sendStatus(
    const std::vector<uint8_t> & responseChallenge) {
  rapidjson::MemoryPoolAllocator<> allocator(defaultChunkSize);
  rapidjson::Document document(rapidjson::kObjectType, &allocator);

  // Web ID is referenced, not copied, into the document.
  document.AddMember("id", rapidjson::StringRef(webId.c_str()),
                     document.GetAllocator());

  // Coprocessor pages to publish, read in this order. Public Key A (x, y)
  // forms "publicKey"; user data pages 14 and 15 form "certificate" (r, s).
  // Each page is added as a hex string; the first failed read aborts with an
  // error window.
  rapidjson::Value publicKey(rapidjson::kObjectType);
  rapidjson::Value certificate(rapidjson::kObjectType);
  struct PageField {
    int pageNum;
    rapidjson::Value * object;
    const char * name;
    const char * errorText;
  };
  const PageField fields[] = {
      {DS2476::publicKeyAxPage, &publicKey, "x",
       "Failed to read Public Key A (x)"},
      {DS2476::publicKeyAyPage, &publicKey, "y",
       "Failed to read Public Key A (y)"},
      {14, &certificate, "r", "Failed to read User Data 14"},
      {15, &certificate, "s", "Failed to read User Data 15"}};
  for (size_t i = 0; i < sizeof(fields) / sizeof(fields[0]); ++i) {
    const Core::Result<DS2476::Page::array> page =
        coproc.readMemory(fields[i].pageNum);
    if (!page) {
      if (windowManager()) {
        std::auto_ptr<Window> errorWindow(new ErrorWindow(fields[i].errorText));
        windowManager()->push(errorWindow);
      }
      return WindowsChanged;
    }
    fields[i].object->AddMember(
        rapidjson::StringRef(fields[i].name),
        rapidjson::Value(toHexString(page.value()).c_str(),
                         document.GetAllocator())
            .Move(),
        document.GetAllocator());
  }
  document.AddMember("publicKey", publicKey, document.GetAllocator());
  document.AddMember("certificate", certificate, document.GetAllocator());

  // Attach a new command challenge, sign everything, and transmit.
  if (finalizeResponse(document, validSignature, responseChallenge)) {
    sendJson(document);
    return NoChange;
  }
  if (windowManager()) {
    std::auto_ptr<Window> errorWindow(new ErrorWindow("Failed to sign data"));
    windowManager()->push(errorWindow);
  }
  return WindowsChanged;
}

NormalOperationWindow::Result NormalOperationWindow::sendObjectTemp(
    const std::vector<uint8_t> & responseChallenge) {
  // Take the reading first; without it there is nothing to report.
  const Core::Result<double> reading =
      sensorNode.readTemp(SensorNode::ObjectTemp);
  if (!reading) {
    if (windowManager()) {
      std::auto_ptr<Window> errorWindow(
          new ErrorWindow("Failed to read object temperature"));
      windowManager()->push(errorWindow);
    }
    return WindowsChanged;
  }
  const double objectTemp = reading.value();

  rapidjson::MemoryPoolAllocator<> pool(defaultChunkSize);
  rapidjson::Document document(rapidjson::kObjectType, &pool);
  document.AddMember("objectTemp", objectTemp, document.GetAllocator());

  // Sign and transmit; the cached value used by doRender() is only updated
  // once the reading has been sent.
  if (finalizeResponse(document, validSignature, responseChallenge)) {
    sendJson(document);
    lastObjectTemp = objectTemp;
    return NoChange;
  }
  if (windowManager()) {
    std::auto_ptr<Window> errorWindow(new ErrorWindow("Failed to sign data"));
    windowManager()->push(errorWindow);
  }
  return WindowsChanged;
}

NormalOperationWindow::Result NormalOperationWindow::sendAmbientTemp(
    const std::vector<uint8_t> & responseChallenge) {
  // Take the reading first; without it there is nothing to report.
  const Core::Result<double> reading =
      sensorNode.readTemp(SensorNode::AmbientTemp);
  if (!reading) {
    if (windowManager()) {
      std::auto_ptr<Window> errorWindow(
          new ErrorWindow("Failed to read ambient temperature"));
      windowManager()->push(errorWindow);
    }
    return WindowsChanged;
  }
  const double ambientTemp = reading.value();

  rapidjson::MemoryPoolAllocator<> pool(defaultChunkSize);
  rapidjson::Document document(rapidjson::kObjectType, &pool);
  document.AddMember("ambientTemp", ambientTemp, document.GetAllocator());

  // Sign and transmit; the cached value used by doRender() is only updated
  // once the reading has been sent.
  if (finalizeResponse(document, validSignature, responseChallenge)) {
    sendJson(document);
    lastAmbientTemp = ambientTemp;
    return NoChange;
  }
  if (windowManager()) {
    std::auto_ptr<Window> errorWindow(new ErrorWindow("Failed to sign data"));
    windowManager()->push(errorWindow);
  }
  return WindowsChanged;
}

// Shows imageData in a new DisplayGraphicWindow pushed onto the window
// manager (if there is one). The bytes are wrapped in a Bitmap constructed
// with 64 as its third argument (presumably the width in pixels -- confirm
// against Bitmap.hpp).
// Empty image data is ignored. The caller passes an empty vector when the
// received hex string is empty or invalid, and &imageData[0] on an empty
// vector is undefined behaviour.
void NormalOperationWindow::displayImage(
    const std::vector<uint8_t> & imageData) {
  if (imageData.empty() || !windowManager()) {
    return;
  }
  std::auto_ptr<Graphic> image(
      new Image(Bitmap(&imageData[0], imageData.size(), 64)));
  std::auto_ptr<Window> window(new DisplayGraphicWindow(image));
  windowManager()->push(window);
}

// Splits the first recvBufSize bytes of recvBuf into individual JSON commands,
// authenticates each one with verifySignedData(), and executes the supported
// commands in order.
//
// Supported "command" values: getStatus, readObjectTemp, readAmbientTemp,
// enableModule, disableModule, displayImage. Unknown commands are ignored.
//
// Returns NoChange when all commands were handled without touching the window
// stack. Otherwise it returns the first non-NoChange result, and any remaining
// commands in this batch are not processed. A command that fails
// authentication, or cannot be verified at all, is reported to the server and
// on screen (when a window manager is attached), and WindowsChanged is
// returned.
NormalOperationWindow::Result
NormalOperationWindow::processReceivedData(size_t recvBufSize) {
  // Separate commands and process each one.
  const std::vector<Core::span<const char> > commands =
      separateCommands(Core::make_span(recvBuf, recvBufSize));
  for (std::vector<Core::span<const char> >::const_iterator it =
           commands.begin();
       it != commands.end(); ++it) {
    // Each command gets its own short-lived allocation pool.
    rapidjson::MemoryPoolAllocator<> allocator(defaultChunkSize);
    rapidjson::Document data(&allocator);
    // Verify command signature.
    const Core::Result<void> verifySignedResult = verifySignedData(data, *it);
    if (verifySignedResult) {
      // Verify command schema.
      // The acknowledgement is sent before the schema below is checked.
      sendMessage("Received data is authentic");
      if (data.IsObject() && data.HasMember("command")) {
        const rapidjson::Value & command = data["command"];
        if (command.IsString()) {
          // Parse challenge if included.
          // A missing, non-string or invalid-hex challenge leaves
          // responseChallenge empty.
          std::vector<uint8_t> responseChallenge;
          if (data.HasMember("challenge")) {
            const rapidjson::Value & challenge = data["challenge"];
            if (challenge.IsString()) {
              responseChallenge =
                  Core::fromHexString(
                      Core::make_span(challenge.GetString(),
                                      challenge.GetStringLength()))
                      .valueOr(std::vector<uint8_t>());
            }
          }

          // Execute the command.
          if (command == "getStatus") {
            const Result result = sendStatus(responseChallenge);
            if (result != NoChange) {
              return result;
            }
          } else if (command == "readObjectTemp") {
            // Temperature reads are only honoured while the sensor node was
            // last detected in one of the two valid states.
            if ((lastSensorNodeState == SensorNode::ValidLaserDisabled) ||
                (lastSensorNodeState == SensorNode::ValidLaserEnabled)) {
              const Result result = sendObjectTemp(responseChallenge);
              if (result != NoChange) {
                return result;
              }
              // Redraw to show the newly cached temperature.
              invalidate();
            }
          } else if (command == "readAmbientTemp") {
            if ((lastSensorNodeState == SensorNode::ValidLaserDisabled) ||
                (lastSensorNodeState == SensorNode::ValidLaserEnabled)) {
              const Result result = sendAmbientTemp(responseChallenge);
              if (result != NoChange) {
                return result;
              }
              invalidate();
            }
          } else if (command == "enableModule") {
            if (lastSensorNodeState == SensorNode::ValidLaserDisabled) {
              // NOTE(review): the state is updated when setLaserEnabled()'s
              // result is falsy. This presumes it returns an error code that is
              // falsy on success rather than a Core::Result -- confirm against
              // SensorNode.
              if (!sensorNode.setLaserEnabled(
                      true, makeFunction(
                                this, &NormalOperationWindow::sendMessage))) {
                lastSensorNodeState = SensorNode::ValidLaserEnabled;
                invalidate();
              }
            }
          } else if (command == "disableModule") {
            if (lastSensorNodeState == SensorNode::ValidLaserEnabled) {
              // NOTE(review): same result convention as enableModule above.
              if (!sensorNode.setLaserEnabled(
                      false, makeFunction(
                                 this, &NormalOperationWindow::sendMessage))) {
                lastSensorNodeState = SensorNode::ValidLaserDisabled;
                invalidate();
              }
            }
          } else if (command == "displayImage") {
            // "image" is a hex string; invalid hex yields an empty vector.
            if (data.HasMember("image")) {
              const rapidjson::Value & image = data["image"];
              if (image.IsString()) {
                displayImage(Core::fromHexString(
                                 Core::make_span(image.GetString(),
                                                 image.GetStringLength()))
                                 .valueOr(std::vector<uint8_t>()));
                return WindowsChanged;
              }
            }
          }
        }
      }
    } else if (verifySignedResult.error() == DS2476::AuthenticationError) {
      // Malformed or wrongly signed command: tell the server and show the
      // message as word-wrapped text.
      const char message[] = "Received data is not authentic";
      sendMessage(message);
      std::auto_ptr<Graphic> messageText(new Text);
      Text & messageTextRef = *static_cast<Text *>(messageText.get());
      messageTextRef.setText(message);
      messageTextRef.setWordWrap(true);
      if (windowManager()) {
        std::auto_ptr<Window> window(new DisplayGraphicWindow(messageText));
        windowManager()->push(window);
      }
      return WindowsChanged;
    } else {
      // Any other error means verification itself could not be carried out.
      const char message[] = "Unable to verify received data";
      sendMessage(message);
      if (windowManager()) {
        std::auto_ptr<Window> window(new ErrorWindow(message));
        windowManager()->push(window);
      }
      return WindowsChanged;
    }
  }
  return NoChange;
}

void NormalOperationWindow::resized() {
  // Both buttons are full-width rows at the bottom of the window: "Show web ID"
  // on the last row, the signature toggle directly above it with a one-pixel
  // gap. The web ID button is placed first because the toggle's position
  // depends on it.
  showWebIdButton.resize(width(), showWebIdButton.preferredHeight());
  showWebIdButton.move(0, height() - showWebIdButton.height());

  validSignatureButton.resize(width(), validSignatureButton.preferredHeight());
  validSignatureButton.move(
      0, showWebIdButton.y() - (validSignatureButton.height() + 1));
}

// Formats input with exactly two digits after the decimal point ("%.2f"),
// e.g. 21.5 -> "21.50". The result is never silently truncated. Text that
// does not fit the stack buffer is formatted again into a string sized from
// snprintf's reported length. Returns an empty string if formatting fails.
static std::string doubleToString(double input) {
  char inputString[32];
  const int length = snprintf(inputString, sizeof(inputString), "%.2f", input);
  if (length < 0) {
    return std::string();
  }
  if (static_cast<size_t>(length) < sizeof(inputString)) {
    return std::string(inputString, static_cast<size_t>(length));
  }
  // Only very large magnitudes get here. Reserve room for the text plus
  // snprintf's terminating NUL, then drop the NUL.
  std::string result(static_cast<size_t>(length) + 1, '\0');
  snprintf(&result[0], result.size(), "%.2f", input);
  result.resize(static_cast<size_t>(length));
  return result;
}

void NormalOperationWindow::doRender(Bitmap & bitmap, int xOffset,
                                     int yOffset) const {
  // Format current status text.
  std::string sensorNodeStateText;
  switch (lastSensorNodeState) {
  case SensorNode::Disconnected:
    sensorNodeStateText = "Disconnected";
    break;

  case SensorNode::Invalid:
    sensorNodeStateText = "Invalid";
    break;

  case SensorNode::ValidLaserDisabled:
    sensorNodeStateText = "Valid, laser disabled";
    break;

  case SensorNode::ValidLaserEnabled:
    sensorNodeStateText = "Valid, laser enabled";
    break;
  }

  Text description;
  description.setText("Object temp: " + doubleToString(lastObjectTemp) +
                      "\nAmbient temp: " + doubleToString(lastAmbientTemp) +
                      "\nSensor node: " + sensorNodeStateText);
  description.resize(width(), validSignatureButton.y());
  description.setWordWrap(true);
  description.render(bitmap, xOffset + x(), yOffset + y());
  validSignatureButton.render(bitmap, xOffset + x(), yOffset + y());
  showWebIdButton.render(bitmap, xOffset + x(), yOffset + y());
}

// Periodic update. It polls the sensor node and redraws on a state change.
// Until the initial command challenge has been sent successfully, it retries
// sending it on each call. After that it polls the socket and processes any
// received commands. A receive error other than "would block" (including a
// return of 0) shows an error window.
void NormalOperationWindow::updated() {
  // Detect sensor node.
  const SensorNode::State sensorNodeState = sensorNode.detect();
  if (sensorNodeState != lastSensorNodeState) {
    lastSensorNodeState = sensorNodeState;
    invalidate();
  }

  // Send challenge on first connection.
  if (sendChallenge) {
    rapidjson::MemoryPoolAllocator<> allocator(defaultChunkSize);
    rapidjson::Document document(rapidjson::kObjectType, &allocator);
    if (addCommandChallenge(document)) {
      sendJson(document);
      sendChallenge = false;
    }
  } else {
    // Process socket data.
    const int recvResult =
        socket->recv(recvBuf, sizeof(recvBuf) / sizeof(recvBuf[0]));
    if (recvResult > 0) {
      // Log the received bytes. The precision in "%.*s" limits output to
      // recvResult characters. recvBuf is not guaranteed to hold a NUL after
      // the received data, and "%*s" (a minimum field width) would keep
      // reading past it.
      std::printf("%.*s\n", recvResult, recvBuf);
      const Result result = processReceivedData(recvResult);
      if (result != NoChange)
        return;
    } else if (recvResult != NSAPI_ERROR_WOULD_BLOCK) {
      if (windowManager()) {
        std::auto_ptr<Window> window(new ErrorWindow("Socket receive failed"));
        windowManager()->push(window);
      }
      return;
    }
  }
}

bool NormalOperationWindow::doProcessKey(Key key) {
  // Up and Down move focus between the two stacked buttons. Any other key is
  // reported as unhandled.
  if (key == UpKey) {
    validSignatureButton.setFocused();
    return true;
  }
  if (key == DownKey) {
    showWebIdButton.setFocused();
    return true;
  }
  return false;
}
