/* -*-  Mode: C++; c-file-style: "gnu"; indent-tabs-mode:nil; -*- */
/*
 * Author: Yu Cao <caoyu08@csnet1.cs.tsinghua.edu.cn>
 */

#include "ns3/assert.h"
#include "ns3/log.h"
#include "ns3/nstime.h"
#include "ns3/boolean.h"
#include "ns3/object-vector.h"

#include "ns3/packet.h"
#include "ns3/node.h"
#include "ns3/simulator.h"
#include "ns3/ipv4-route.h"

#include "tcp-l4-protocol.h"
#include "nampt-l4-protocol.h"
#include "nampt-socket.h"
#include "tcp-header.h"
#include "ipv4-end-point-demux.h"
#include "ipv4-end-point.h"
#include "ipv4-l3-protocol.h"
#include "tcp-socket-factory-impl.h"
#include "tcp-newreno.h"
#include "rtt-estimator.h"

#include <vector>
#include <sstream>
#include <iomanip>

NS_LOG_COMPONENT_DEFINE ("NaMPTL4Protocol");

namespace ns3 {

NS_OBJECT_ENSURE_REGISTERED (NaMPTL4Protocol);

/*TypeId 
FlowItem::GetTypeId (void)
{
  static TypeId tid = TypeId ("ns3::FlowItem")
    .SetParent<Object> ()
    .AddConstructor<FlowItem> ()
    ;
  return tid;
}*/

/*
 * Build a flow-table entry for the subnet that contains `addr` under `mask`.
 * The stored network is the masked address; both counters start at zero
 * (m_nFrom counts recorded flows originating in this subnet, m_nTo counts
 * recorded flows destined to it).
 */
FlowItem::FlowItem (const Ipv4Address& addr, const Ipv4Mask& mask)
{
  m_nFrom = 0;
  m_nTo = 0;
  m_network = addr.CombineMask (mask);
  m_mask = mask;
}

/* The destructor performs no explicit cleanup. */
FlowItem::~FlowItem (void)
{
}

// Period between runs of NaMPTL4Protocol::DCNNATPollingWoker; the value is
// wrapped in Seconds() when the polling event is scheduled.
const double NATItem::POLLING_INTERVAL = 0.1;
// Value m_ttl is set/refreshed to for a live NAT entry. Once an entry's
// m_flag reaches 3 (bumped by FIN/RST packets in Receive), each polling pass
// decrements m_ttl, and the entry is recycled onto the free list at 0.
const uint8_t NATItem::TTL = 3;

/* Register NATItem with the ns-3 type system (parent: Object, default-constructible). */
TypeId
NATItem::GetTypeId (void)
{
  static TypeId tid = TypeId ("ns3::NATItem").SetParent<Object> ().AddConstructor<NATItem> ();
  return tid;
}

/*
 * Default-construct a NAT entry in the "invalid" state.
 *
 * m_ttl == 0 is the marker NaMPTL4Protocol::Receive uses to reject packets
 * for an unused entry, and DCNNATPollingWoker skips such entries. Every field
 * is overwritten by ApplyForDCNNAT when the entry is handed out. Initializing
 * the fields here means an entry created any other way (e.g. through the
 * registered TypeId constructor) never exposes indeterminate values.
 */
NATItem::NATItem (void)
{
  m_addrs[0] = 0;
  m_addrs[1] = 0;
  m_ports[0] = 0;
  m_ports[1] = 0;
  m_ttl = 0;
  m_flag = 0;
  m_next = -1; // not linked into the free list
}

/* The destructor performs no explicit cleanup. */
NATItem::~NATItem (void)
{
}

/////////////////////////////////////////

/* Register NaMPTL4Protocol with the ns-3 type system as a TcpL4Protocol subclass. */
TypeId
NaMPTL4Protocol::GetTypeId (void)
{
  static TypeId tid = TypeId ("ns3::NaMPTL4Protocol").SetParent<TcpL4Protocol> ().AddConstructor<NaMPTL4Protocol> ();
  return tid;
}

/*
 * Construct the protocol with DCN NAT disabled and no NAT entries.
 * m_freeNAT == -1 marks the NAT free list as empty. Token numbering starts
 * at 1 because RegisterToken never issues token 0.
 */
NaMPTL4Protocol::NaMPTL4Protocol ()
{
  NS_LOG_FUNCTION_NOARGS ();
  NS_LOG_LOGIC ("Made a NaMPTL4Protocol " << this);
  m_enableDCNNAT = false;
  m_nNAT = 0;
  m_freeNAT = -1;
  m_tokenSource = 1;
}

/* Resources are released in DoDispose; the destructor only logs. */
NaMPTL4Protocol::~NaMPTL4Protocol ()
{
  NS_LOG_FUNCTION_NOARGS ();
}

/*
 * Create a NaMPT socket attached to this node and protocol instance.
 *
 * The socket records `socketTypeId` and this protocol's RTT estimator TypeId.
 * Per the original design note, the real underlying socket (m_firstSocket)
 * is only created when bind() is invoked on it. This method does not append
 * the socket to m_sockets.
 */
Ptr<Socket>
NaMPTL4Protocol::CreateSocket (TypeId socketTypeId)
{
  NS_LOG_FUNCTION_NOARGS ();

  Ptr<NaMPTSocket> namptSocket = CreateObject<NaMPTSocket> ();
  if (namptSocket == 0)
    {
      return namptSocket;
    }

  namptSocket->SetNode (m_node);
  namptSocket->SetTcp (this);
  namptSocket->SetSocketTypeId (socketTypeId);
  namptSocket->SetRttTypeId (m_rttTypeId);
  return namptSocket;
}

/* Create a NaMPT socket using the protocol's default socket TypeId (m_socketTypeId). */
Ptr<Socket>
NaMPTL4Protocol::CreateSocket (void)
{
  TypeId defaultType = m_socketTypeId;
  return CreateSocket (defaultType);
}

void
NaMPTL4Protocol::DoDispose (void)
{
  NS_LOG_FUNCTION_NOARGS ();
  TokenDB_t::iterator it;
  for(it = m_tokenDB.begin (); it != m_tokenDB.end (); it++)
    {
      it->second = 0;
    }
  m_tokenDB.clear ();

  m_enableDCNNAT = false;
  m_pollingEvent.Cancel();
  DCNNATTable_t::iterator nat = m_dcnNATTable.begin();
  for (; nat != m_dcnNATTable.end(); nat++)
  	{
  	  (*nat) = 0;
  	}
  m_dcnNATTable.clear();
  m_nNAT = 0;
  m_freeNAT = -1;

  FlowTable_t::iterator flow = m_flowTable.begin();
  for (; flow != m_flowTable.end(); flow++)
  	{
  	  (*flow) = 0;
  	}
  m_flowTable.clear();

  TcpL4Protocol::DoDispose ();
}

/* 
 * This method is called by AddAgregate and completes the aggregation
 * by setting the node in the NaMPT stack, link it to the ipv4 stack and 
 * adding TCP socket factory to the node.
 */
void
NaMPTL4Protocol::NotifyNewAggregate ()
{
	if (m_node == 0)
	{
		Ptr<Node> node = this->GetObject<Node> ();
		if (node != 0)
		{
			Ptr<Ipv4L3Protocol> ipv4 = this->GetObject<Ipv4L3Protocol> ();
			if (ipv4 != 0)
			{
				this->SetNode (node);
				ipv4->Insert (this);
				Ptr<TcpSocketFactoryImpl> tcpFactory = CreateObject<TcpSocketFactoryImpl> ();
				tcpFactory->SetTcp (this);
				node->AggregateObject (tcpFactory);
				this->SetDownTarget (MakeCallback(&Ipv4L3Protocol::Send, ipv4));
			}
		}
	}
	Object::NotifyNewAggregate ();
}

/*
 * Entry point for segments delivered by IPv4.
 *
 * DCN NAT relay (when NAT is enabled or NAT entries are still alive):
 *  - The NaMPT destination port selects a NAT entry (port = table index + 1).
 *  - The IPv4 source must be one of the entry's two endpoints.
 *  - The packet is rewritten to come from this node and the entry's port, and
 *    is forwarded to the other endpoint.
 *  - Ports that do not name an existing entry are rejected.
 *
 * Local delivery (otherwise):
 *  - For SYNs where the peer is NaMPT-enabled (FindSignal (NAMPT_NONE, false)
 *    is non-null), the destination address must belong to the incoming
 *    interface.
 *  - SYNs with a JOIN signal go to the socket owning the peer token.
 *  - Everything else is handled by TcpL4Protocol::Receive.
 */
enum IpL4Protocol::RxStatus
NaMPTL4Protocol::Receive (Ptr<Packet> p, Ipv4Header const &header,
                          Ptr<Ipv4Interface> incomingInterface)
{
  NaMPTHeader namptHeader;

  // DCN NAT
  if (m_enableDCNNAT || m_nNAT != 0)
    {
      p->RemoveHeader (namptHeader);

      // The destination port comes from the wire: reject port 0 and ports
      // beyond the table instead of indexing m_dcnNATTable out of bounds.
      uint32_t natPort = namptHeader.GetDestinationPort ();
      if (natPort == 0 || natPort > m_dcnNATTable.size ())
        {
          return IpL4Protocol::RX_ENDPOINT_UNREACH;
        }
      uint32_t index = natPort - 1;
      Ptr<NATItem> item = m_dcnNATTable[index];
      if (item->m_ttl == 0) // invalid (unused or recycled) item
        {
          return IpL4Protocol::RX_ENDPOINT_UNREACH;
        }
      item->m_ttl = NATItem::TTL;

      // The sender must be one of the entry's two endpoints.
      uint32_t from = 0;
      for (; from < 2; from++)
        {
          if (header.GetSource ().Get () == item->m_addrs[from])
            {
              break;
            }
        }
      if (from >= 2)
        {
          return IpL4Protocol::RX_ENDPOINT_UNREACH;
        }

      // Learn the sender's port on first contact; afterwards it must match.
      if (item->m_ports[from] == 0)
        {
          item->m_ports[from] = namptHeader.GetSourcePort ();
        }
      else if (item->m_ports[from] != namptHeader.GetSourcePort ())
        {
          return IpL4Protocol::RX_ENDPOINT_UNREACH;
        }

      // This flow is being closed. m_flag starts at 1 (ApplyForDCNNAT) and
      // each FIN/RST bumps it; at 3 the flow is dropped from the statistics.
      if (namptHeader.GetFlags () & (TcpHeader::FIN | TcpHeader::RST))
        {
          item->m_flag++;
          if (item->m_flag == (uint8_t)3)
            {
              EraseFlow (Ipv4Address (item->m_addrs[0]), Ipv4Address (item->m_addrs[1]));
            }
        }

      // Rewrite addresses and ports toward the other endpoint.
      uint32_t to = (from + 1) % 2;
      Ipv4Header newHeader = header;
      newHeader.SetSource (header.GetDestination ());
      newHeader.SetDestination (Ipv4Address (item->m_addrs[to]));
      namptHeader.SetSourcePort (index + 1);
      namptHeader.SetDestinationPort (item->m_ports[to]);
      p->AddHeader (namptHeader);

      // Route and send out.
      Ptr<Ipv4L3Protocol> ipv4 = m_node->GetObject<Ipv4L3Protocol> ();
      if (ipv4 == 0 || ipv4->GetRoutingProtocol () == 0)
        {
          return IpL4Protocol::RX_ENDPOINT_UNREACH;
        }
      Socket::SocketErrno errno_;
      Ptr<Ipv4Route> route = ipv4->GetRoutingProtocol ()->RouteOutput (p, newHeader, 0, errno_);
      if (route == 0)
        {
          return IpL4Protocol::RX_ENDPOINT_UNREACH;
        }
      ipv4->SendWithHeader (p, newHeader, route);
      return IpL4Protocol::RX_OK;
    }
  //=======================================

  p->PeekHeader (namptHeader);

  // SYN packet, so check its incoming interface.
  if (namptHeader.GetFlags () == TcpHeader::SYN)
    {
      // If the peer is NaMPT-enabled, the destination address must be one of
      // the incoming interface's local addresses.
      if (namptHeader.FindSignal (NAMPT_NONE, false))
        {
          bool bFound = false;
          for (uint32_t i = 0; i < incomingInterface->GetNAddresses (); i++)
            {
              if (header.GetDestination () == incomingInterface->GetAddress (i).GetLocal ())
                {
                  bFound = true;
                  break;
                }
            }
          if (!bFound)
            {
              return IpL4Protocol::RX_ENDPOINT_UNREACH;
            }
        }

      // A request to join an existing connection: look up the token DB.
      Ptr<NaMPTSignal> sig = namptHeader.FindSignal (NAMPT_JOIN, false);
      if (sig != 0)
        {
          Ptr<NaMPTSocket> sock = LookupToken (sig->sigData.join.peerToken);
          if (sock == 0)
            {
              return IpL4Protocol::RX_ENDPOINT_UNREACH;
            }
          sock->ReceiveSubflowJoin (p, header, incomingInterface);
          return IpL4Protocol::RX_OK;
        }
    }

  // Ordinary TCP processing.
  return TcpL4Protocol::Receive (p, header, incomingInterface);
}


/*
  * Invoked by TcpSocketBase
*/
/*
 * Invoked by TcpSocketBase to transmit a segment.
 *
 * Steps:
 *  1. Convert the plain TcpHeader into a NaMPTHeader.
 *  2. For segments with the SYN flag set (SYN and SYN-ACK) that belong to a
 *     socket registered in the token database, attach an MPC signal carrying
 *     that socket's local token.
 *  3. Prepend the header and look up a route to daddr through the node's
 *     IPv4 routing protocol.
 *  4. Hand the packet to m_downTarget.
 */
void
NaMPTL4Protocol::SendPacket (Ptr<Packet> packet, const TcpHeader &outgoing,
                             Ipv4Address saddr, Ipv4Address daddr, Ptr<NetDevice> oif)
{
  NS_LOG_LOGIC ("NaMPTL4Protocol " << this
                << " sending seq " << outgoing.GetSequenceNumber ()
                << " ack " << outgoing.GetAckNumber ()
                << " flags " << std::hex << (int)outgoing.GetFlags () << std::dec
                << " data size " << packet->GetSize ());
  NS_LOG_FUNCTION (this << packet << saddr << daddr << oif);

  NaMPTHeader outgoingHeader (outgoing);

  // checksum
  if (Node::ChecksumEnabled ())
    {
      outgoingHeader.EnableChecksums ();
    }
  outgoingHeader.InitializeChecksum (saddr, daddr, PROT_NUMBER);

  // Insert the MPC signal into SYN segments of NaMPT sockets.
  if (outgoingHeader.GetFlags () & TcpHeader::SYN)
    {
      Ptr<NaMPTSocket> sk = LookupSocketFromTokenDB (
          outgoingHeader.GetSourcePort (), saddr,
          outgoingHeader.GetDestinationPort (), daddr);
      if (sk != 0)
        {
          // Use the find-based LookupToken: m_tokenDB[...] inside the
          // assertion would insert a null entry in debug builds whenever the
          // token is missing.
          NS_ASSERT (LookupToken (sk->GetLocalToken ()) == sk);
          Ptr<NaMPTSignal> sig = outgoingHeader.AddSignal (NAMPT_MPC);
          sig->sigData.mpc.localToken = sk->GetLocalToken ();
        }
    }

  // serialize
  packet->AddHeader (outgoingHeader);

  // routing
  Ptr<Ipv4L3Protocol> ipv4 = m_node->GetObject<Ipv4L3Protocol> ();
  if (ipv4 != 0)
    {
      Ipv4Header header;
      header.SetDestination (daddr);
      header.SetProtocol (PROT_NUMBER);
      Socket::SocketErrno errno_;
      Ptr<Ipv4Route> route;
      if (ipv4->GetRoutingProtocol () != 0)
        {
          route = ipv4->GetRoutingProtocol ()->RouteOutput (packet, header, oif, errno_);
        }
      else
        {
          NS_LOG_ERROR ("No IPV4 Routing Protocol");
          route = 0;
        }
      m_downTarget (packet, saddr, daddr, PROT_NUMBER, route);
    }
  else
    {
      NS_FATAL_ERROR ("Trying to use Tcp on a node without an Ipv4 interface");
    }
}

/*
  * Invoked by NaMPTSubflow because NaMPTHeader is used.
*/
/*
 * Send a segment whose header is already a NaMPTHeader (used by NaMPT
 * subflows).
 *
 * Steps:
 *  1. Prepare the checksum fields of `outgoing` (checksums are enabled only
 *     when Node::ChecksumEnabled ()).
 *  2. Prepend the header to the packet.
 *  3. Look up a route to daddr via RouteOutput with `oif`.
 *  4. Hand the packet to m_downTarget.
 *
 * Note: `outgoing` is modified by this call.
 */
void
NaMPTL4Protocol::SendPacket (Ptr<Packet> packet, NaMPTHeader &outgoing,
                             Ipv4Address saddr, Ipv4Address daddr, Ptr<NetDevice> oif)
{
  NS_LOG_LOGIC ("NaMPTL4Protocol " << this
                << " sending seq " << outgoing.GetSequenceNumber ()
                << " ack " << outgoing.GetAckNumber ()
                << " flags " << std::hex << (int)outgoing.GetFlags () << std::dec
                << " data size " << packet->GetSize ());
  NS_LOG_FUNCTION (this << packet << saddr << daddr << oif);

  bool const checksums = Node::ChecksumEnabled ();
  if (checksums)
    {
      outgoing.EnableChecksums ();
    }
  outgoing.InitializeChecksum (saddr, daddr, PROT_NUMBER);
  packet->AddHeader (outgoing);

  Ptr<Ipv4L3Protocol> ipv4 = m_node->GetObject<Ipv4L3Protocol> ();
  if (ipv4 == 0)
    {
      NS_FATAL_ERROR ("Trying to use Tcp on a node without an Ipv4 interface");
      return;
    }

  Ipv4Header header;
  header.SetDestination (daddr);
  header.SetProtocol (PROT_NUMBER);

  Ptr<Ipv4Route> route = 0;
  if (ipv4->GetRoutingProtocol () == 0)
    {
      NS_LOG_ERROR ("No IPV4 Routing Protocol");
    }
  else
    {
      Socket::SocketErrno errno_;
      route = ipv4->GetRoutingProtocol ()->RouteOutput (packet, header, oif, errno_);
    }
  m_downTarget (packet, saddr, daddr, PROT_NUMBER, route);
}

/*
  * Maintain tokens
*/
/*
 * Linear search of the token database for a socket with
 * local (port, address) == (sport, saddr) and peer (port, address) == (dport, daddr).
 * Sockets that cannot report a peer or local address are skipped.
 * Returns 0 when no registered socket matches.
 */
Ptr<NaMPTSocket>
NaMPTL4Protocol::LookupSocketFromTokenDB (
  uint16_t sport, const Ipv4Address &saddr,
  uint16_t dport, const Ipv4Address &daddr)
{
  for (TokenDB_t::const_iterator entry = m_tokenDB.begin (); entry != m_tokenDB.end (); ++entry)
    {
      const Ptr<NaMPTSocket> &candidate = entry->second;

      bool portsMatch = candidate->GetPeerPort () == dport
                        && candidate->GetLocalPort () == sport;
      if (!portsMatch)
        {
          continue;
        }

      Ipv4Address address;
      if (candidate->GetPeerAddress (address) && address == daddr
          && candidate->GetLocalAddress (address) && address == saddr)
        {
          return candidate;
        }
    }
  return 0;
}

/* Return the socket registered under `token`, or 0 if the token is unknown. */
Ptr<NaMPTSocket>
NaMPTL4Protocol::LookupToken (uint32_t token)
{
  TokenDB_t::iterator hit = m_tokenDB.find (token);
  return (hit != m_tokenDB.end ()) ? hit->second : Ptr<NaMPTSocket> ();
}

/*
 * Register `sock` under a fresh token and return that token.
 * Tokens are drawn sequentially from m_tokenSource, skipping 0 and any value
 * still present in the token database.
 */
uint32_t
NaMPTL4Protocol::RegisterToken (const Ptr<NaMPTSocket>& sock)
{
  uint32_t token;
  do
    {
      token = m_tokenSource++;
    }
  while (token == 0 || m_tokenDB.find (token) != m_tokenDB.end ());

  m_tokenDB[token] = sock;
  return token;
}

/* Remove `token` from the token database (releasing its socket reference); unknown tokens are ignored. */
void
NaMPTL4Protocol::UnregisterToken (uint32_t token)
{
  TokenDB_t::iterator hit = m_tokenDB.find (token);
  if (hit == m_tokenDB.end ())
    {
      return;
    }
  hit->second = 0;
  m_tokenDB.erase (hit);
}

void 
NaMPTL4Protocol::EnableDCNNAT(void)
{
  m_enableDCNNAT = true;
  m_pollingEvent.Cancel();
  m_pollingEvent = Simulator::Schedule (Seconds(NATItem::POLLING_INTERVAL), 
  								&NaMPTL4Protocol::DCNNATPollingWoker, this);
}

/*
 * Stop handing out new NAT entries (ApplyForDCNNAT then returns 0).
 * Existing entries keep being relayed by Receive while m_nNAT != 0. The
 * polling timer stops once NAT is disabled and no entries remain.
 */
void
NaMPTL4Protocol::DisableDCNNAT (void)
{
  m_enableDCNNAT = false;
}

/*
 * Allocate a NAT entry relaying between srcAddr and (desAddr, desPort).
 *
 * The head of the free list is reused when available; otherwise a new
 * NATItem is appended to the table. The source-side port starts at 0 and is
 * learned from the first relayed packet (see Receive). The flow is also
 * counted in the flow table.
 *
 * Returns the relay port (table index + 1), or 0 when NAT is disabled or
 * m_nNAT has already reached 65535.
 */
uint16_t
NaMPTL4Protocol::ApplyForDCNNAT (const Ipv4Address& desAddr, uint16_t desPort, const Ipv4Address& srcAddr)
{
  if (!m_enableDCNNAT || m_nNAT >= 65535)
    {
      return 0;
    }

  int32_t slot;
  Ptr<NATItem> entry;
  if (m_freeNAT != -1)
    {
      // Pop the head of the free list.
      slot = m_freeNAT;
      entry = m_dcnNATTable.at (slot);
      m_freeNAT = entry->m_next;
    }
  else
    {
      // Free list is empty: grow the table.
      entry = CreateObject<NATItem> ();
      slot = m_dcnNATTable.size ();
      m_dcnNATTable.push_back (entry);
    }

  entry->m_addrs[0] = srcAddr.Get ();
  entry->m_ports[0] = 0;
  entry->m_addrs[1] = desAddr.Get ();
  entry->m_ports[1] = desPort;
  entry->m_ttl = NATItem::TTL;
  entry->m_flag = 1;
  entry->m_next = -1;

  m_nNAT++;
  RecordFlow (srcAddr, desAddr);
  return slot + 1;
}

void 
NaMPTL4Protocol::DCNNATPollingWoker(void)
{
  for (uint32_t index = 0; index < m_dcnNATTable.size(); index++)
  {
    if (m_dcnNATTable[index]->m_ttl == 0)
	  continue;
	if (m_dcnNATTable[index]->m_flag >= 3)
	  m_dcnNATTable[index]->m_ttl--;
	if (m_dcnNATTable[index]->m_ttl == 0)
	{
	  if (m_nNAT > 0)
	  	m_nNAT--;
	  m_dcnNATTable[index]->m_flag = 0;
	  m_dcnNATTable[index]->m_next = m_freeNAT;
	  m_freeNAT = index;
	}
  }

  // next invoking
  m_pollingEvent.Cancel();
  if (!m_enableDCNNAT && m_nNAT == 0)
  	return ;
  m_pollingEvent = Simulator::Schedule (Seconds(NATItem::POLLING_INTERVAL), 
  								&NaMPTL4Protocol::DCNNATPollingWoker, this);
}

void
NaMPTL4Protocol::RecordFlow(const Ipv4Address& srcAddr, const Ipv4Address& desAddr)
{
/*
Note: Because the communication pattern is unidirectional, we do not increase m_nTo for srdAddr, 
or increase m_nFrom for desAddr. 
*/
  bool bFindSrc = false, bFindDes = false;
  for (uint32_t i = 0; i < m_flowTable.size() && (!bFindSrc||!bFindDes); i++)
  {
    Ptr<FlowItem>& item = m_flowTable[i];
    if (!bFindSrc && item->m_mask.IsMatch(item->m_network, srcAddr) )
    {
      bFindSrc = true;
	  item->m_nFrom++;
	  continue;
    }
	if (!bFindDes && item->m_mask.IsMatch(item->m_network, desAddr) )
	{
	  bFindDes = true;
	  item->m_nTo++;
	  continue;
	}
  }
  if (!bFindSrc)
  {
    Ptr<FlowItem> item = CreateObject<FlowItem>(srcAddr, Ipv4Mask("255.255.255.0") );
	item->m_nFrom = 1;
	m_flowTable.push_back(item);
  }
  if (!bFindDes)
  {
    Ptr<FlowItem> item = CreateObject<FlowItem>(desAddr, Ipv4Mask("255.255.255.0") );
	item->m_nTo = 1;
	m_flowTable.push_back(item);
  }
}

void
NaMPTL4Protocol::EraseFlow(const Ipv4Address& srcAddr, const Ipv4Address& desAddr)
{
  bool bFindSrc = false, bFindDes = false;
  for (uint32_t i = 0; i < m_flowTable.size() && (!bFindSrc||!bFindDes); i++)
  {
    Ptr<FlowItem>& item = m_flowTable[i];
    if (!bFindSrc && item->m_mask.IsMatch(item->m_network, srcAddr) )
    {
      bFindSrc = true;
	  if (item->m_nFrom > 0)
	  	item->m_nFrom--;
	  continue;
    }
	if (!bFindDes && item->m_mask.IsMatch(item->m_network, desAddr) )
	{
	  bFindDes = true;
	  if (item->m_nTo > 0)
	  	item->m_nTo--;
	  continue;
	}
  }
  NS_ASSERT(bFindSrc && bFindDes);
}

/*
 * Sum m_nFrom over every flow-table entry whose network matches `addr`
 * under `mask`: the number of recorded flows originating from that subnet.
 */
uint32_t
NaMPTL4Protocol::QueryInputFlows (const Ipv4Address& addr, const Ipv4Mask& mask)
{
  uint32_t total = 0;
  for (FlowTable_t::const_iterator it = m_flowTable.begin (); it != m_flowTable.end (); ++it)
    {
      if (mask.IsMatch (addr, (*it)->m_network))
        {
          total += (*it)->m_nFrom;
        }
    }
  return total;
}

/*
 * Sum m_nTo over every flow-table entry whose network matches `addr`
 * under `mask`: the number of recorded flows destined to that subnet.
 */
uint32_t
NaMPTL4Protocol::QueryOutputFlows (const Ipv4Address& addr, const Ipv4Mask& mask)
{
  uint32_t total = 0;
  for (FlowTable_t::const_iterator it = m_flowTable.begin (); it != m_flowTable.end (); ++it)
    {
      if (mask.IsMatch (addr, (*it)->m_network))
        {
          total += (*it)->m_nTo;
        }
    }
  return total;
}

}; // namespace ns3

