#ifndef SNIFFER_H
#define SNIFFER_H

#include "mbed.h"

#include "util/types.h"
#include "net/net.h"

#include <cstdio>
#include <cstring>
#include <functional>

/**
  \file sniffer.h
  \brief Ethernet packet Sniffer
  
  This file is the bread and butter of the NetTool; it decodes incoming Ethernet frames and
  constructs outgoing ones at the bit and byte level.
*/

/// Base callable taking two arguments and returning a Result.
///
/// function_handler and member_handler derive from this so that a caller (e.g.
/// Sniffer::tcp_handler) can hold either kind behind a single base pointer and
/// invoke it polymorphically.
template <class Arg1, class Arg2, class Result>
class handler
{
public:
  /// Virtual so derived handlers can be deleted safely through a handler* pointer.
  virtual ~handler() {}

  /// Invoke the handler.
  ///
  /// The base implementation ignores its arguments and returns a value-initialized
  /// Result. For Result = void this is simply a no-op. Previously the body returned
  /// nothing, which is undefined behaviour for non-void Result.
  virtual inline Result operator() (Arg1, Arg2) const { return Result(); }
};

/// Adapts a plain (non-member) function pointer to the handler interface.
template <class Arg1, class Arg2, class Result>
class function_handler
: public handler <Arg1,Arg2,Result>
{
protected:
  /// The free function invoked on every call.
  Result (*pfunc)(Arg1,Arg2);

public:
  /// Wrap the free function `fn`.
  explicit inline function_handler ( Result (*fn)(Arg1,Arg2) )
  : pfunc (fn)
  {
  }

  /// Forward both arguments to the wrapped function and hand back its result.
  virtual inline Result operator() (Arg1 first, Arg2 second) const
  {
    return (*pfunc)(first, second);
  }
};

/// Adapts an object pointer plus a pointer-to-member-function to the handler interface.
template <class Type, class Arg1, class Arg2, class Result>
class member_handler
: public handler <Arg1,Arg2,Result>
{
protected:
  /// Object on which the member function is invoked (not owned).
  Type *inst;
  /// Member function called on `inst`.
  Result (Type::*pfunc)(Arg1,Arg2);

public:
  /// Bind member function `method` to object `target`.
  explicit inline member_handler ( Type *target, Result (Type::*method)(Arg1,Arg2) )
  : inst (target),
    pfunc (method)
  {
  }

  /// Call `method` on the bound object with both arguments and return its result.
  virtual inline Result operator() (Arg1 first, Arg2 second) const
  {
    Type &object = *inst;
    return (object.*pfunc)(first, second);
  }
};

/// Demo - Ethernet Packet Sniffer
class Sniffer {
public:
  Ethernet_MAC mac;
  
private:
  // Ethernet interface
  Ethernet     eth;
  IP_Address   addr;
  
  // Status LEDs
  DigitalOut linked;
  DigitalOut received;

  // Frame data (big enough for largest ethernet frame)
  int  frame_size;
  char frame[0x600];
  
  // Outgoing frames
  char outframe[0x600];
  Ethernet_FrameHeader *outframe_header;

public:  
  /// Ethernet Frame Header (incoming)
  Ethernet_FrameHeader *frame_header;
  
  /// IP Packet Header (incoming)
  IP_PacketHeader *ip_packet;
  
  /// ARP Packet (incoming)
  ARP_Packet *arp_packet;
  
  /// TCP Packet (incoming)
  TCP_SegmentHeader *tcp_packet;
  
  /// UDP Packet (incoming)
  UDP_Packet *udp_packet;
  
  /// ICMP Packet (incoming)
  ICMP_Packet *icmp_packet;
  
  /// Generic - total data bytes
  unsigned int data_bytes;

public:
  /// Constructor
  inline Sniffer()
  : linked(LED1), received(LED2)
  {
    eth.set_link(Ethernet::AutoNegotiate);
    eth.address((char *)mac.octet);
  }
  
  /// Inject the raw ethernet frame
  inline bool inject(void *data, unsigned int bytes)
  {
    // Send the packet
    eth.write((char*)data, bytes);
    int send_status = eth.send();

    //decode_ethernet(data);
    
    return send_status;
  }
  
  /// Inject the raw payload into an ethernet frame with the given destination and ethertype
  inline bool inject(Ethernet_MAC dest, u16 ethertype, void *packet, unsigned int bytes)
  {
    memset(outframe, 0x00, bytes);
    
    outframe_header = (Ethernet_FrameHeader*)outframe;
    
    // Set the ethernet frame source
    memcpy(&outframe_header->source, mac.octet, 6);
    
    // Set the ethernet frame destination
    outframe_header->destination = dest;
    
    // Set the ethernet ethertype
    outframe_header->ethertype = ethertype;
    
    // Make sure the payload won't be too large
    if (sizeof(Ethernet_FrameHeader) + bytes > sizeof(outframe))
    {
      printf("ERROR: Attempt to inject packet failed; Payload size of %d is too large /n", bytes);
      return false;
    }
    
    // Set the payload
    memcpy(outframe_header->payload, packet, bytes);
    fix_endian_ethernet(outframe_header);
    
    // Send the packet
    eth.write(outframe, sizeof(Ethernet_FrameHeader) + bytes);
    int send_status = eth.send();

    //decode_ethernet(outframe);

    return send_status;
  }
  
  /// Wait until there is more data to receive
  inline void wait_for_data()
  { 
    while (true)
    {
      wait(0.0001);
      
      if (!(linked = eth.link()))
        continue;
      
      received = (frame_size = eth.receive());
      if (!frame_size)
        continue;
      
      eth.read(frame, frame_size);
      break;
    }
  }

  /// Wait for an ethernet frame (will be stored in appropriate class member pointers)
  inline void next()
  {
    wait_for_data();
    
    // Zero out all of the packet pointers
    frame_header = NULL;
    arp_packet = NULL;
    icmp_packet = NULL;
    tcp_packet = NULL;
    udp_packet = NULL;
    data_bytes = 0;
    
    decode_ethernet(frame);
  }
  
  /// Decode the given ethernet frame
  inline void decode_ethernet(void *frame)
  {
    Ethernet_FrameHeader *header = frame_header = (Ethernet_FrameHeader*)frame;
    fix_endian_ethernet(header);
    
    switch (header->ethertype)
    {
      case ETHERTYPE_IPV4:
      case ETHERTYPE_IPV6:
        decode_ip((IP_PacketHeader*)header->payload);
        break;
      case ETHERTYPE_ARP:
        decode_arp((ARP_Packet*)header->payload);
        break;
      default:
        break; // Unknown ethertype
    }
  }
  
  /// Decode the given ARP packet
  inline void decode_arp(ARP_Packet *packet)
  {
    fix_endian_arp(packet);
    if (packet->hardware_type != 0x0001 || packet->protocol_type != 0x0800) return;
    arp_packet = packet;
  }
  
  /// Decode the given IPv4 packet
  inline void decode_ip(IP_PacketHeader *packet)
  {
    u16 chk = checksum(packet, sizeof(IP_PacketHeader), &packet->header_checksum, 2);
    fix_endian_ip(packet);
    ip_packet = packet;
  
    if (packet->version != 4) return;
    
    data_bytes = packet->packet_bytes;
    data_bytes -= sizeof(IP_PacketHeader);
    
    if (packet->protocol == IPPROTO_UDP)
    {
      UDP_Packet *segment = udp_packet = (UDP_Packet*)packet->data;
      fix_endian_udp(segment);
      data_bytes -= sizeof(UDP_Packet);
    }
    else if (packet->protocol == IPPROTO_ICMP)
    {
      ICMP_Packet *segment = icmp_packet = (ICMP_Packet *)packet->data;
      fix_endian_icmp(segment);
      data_bytes -= sizeof(ICMP_Packet);
    }
    else if (packet->protocol == IPPROTO_TCP)
    {
      TCP_SegmentHeader *segment = tcp_packet = (TCP_SegmentHeader*)packet->data;
      fix_endian_tcp(segment);
      data_bytes -= sizeof(TCP_SegmentHeader);
      dispatch_tcp(segment,data_bytes);
    }
  }
  
  handler<TCP_SegmentHeader*,u32,void> *tcp_handler;
  inline void dispatch_tcp(TCP_SegmentHeader *tcp_packet, u32 data_bytes)
  {
    if (tcp_handler) (*tcp_handler)(tcp_packet,data_bytes);
  }
  
  /// Attach a member function to be called on all TCP packets
  template <class T>
  inline void attach_tcp(T *inst, void (T::*func)(TCP_SegmentHeader *tcp_packet, u32 data_bytes))
  {
    tcp_handler = new member_handler<T,TCP_SegmentHeader*,u32,void>(inst, func);
  }
  
  /// Attach a non-member function to be called on all TCP packets
  inline void attach_tcp(void (*func)(TCP_SegmentHeader *tcp_packet, u32 data_bytes))
  {
    tcp_handler = new function_handler<TCP_SegmentHeader*,u32,void>(func);
  }
};

#endif // SNIFFER_H