/* -*-  Mode: C++; c-file-style: "gnu"; indent-tabs-mode:nil; -*- */
/*
 * Author: Yu Cao <caoyu08@csnet1.cs.tsinghua.edu.cn>
 */

#include <stdint.h>
#include <math.h>

#include "ns3/assert.h"
#include "ns3/log.h"
#include "ns3/nstime.h"
#include "ns3/boolean.h"
#include "ns3/object-vector.h"
#include "ns3/packet.h"
#include "ns3/simulator.h"
#include "ns3/packet.h"
#include "ns3/names.h"
#include "ns3/node.h"

#include "nampt-subflow.h"
#include "tcp-header.h"
#include "tcp-socket-base.h"
#include "nampt-l4-protocol.h"
#include "nampt-cc.h"


NS_LOG_COMPONENT_DEFINE ("NaMPTSubflow");

namespace ns3 {

NS_OBJECT_ENSURE_REGISTERED (NaMPTSubflow);


/*
 * Register NaMPTDSN with the ns-3 type system as a plain Object subclass.
 * No attributes and no factory constructor are declared here.
 */
TypeId
NaMPTDSN::GetTypeId (void)
{
  static TypeId tid = TypeId ("ns3::NaMPTDSN").SetParent<Object> ();
  return tid;
}

/*
 * Build one DSN mapping: dataLen bytes starting at data-level sequence number
 * dataSeq, carried on the subflow at [sfSeq, sfSeq + dataLen).
 * A new mapping is not yet advertised and has a zero sequence tag.
 */
NaMPTDSN::NaMPTDSN (uint32_t dataSeq, uint32_t dataLen, uint32_t sfSeq)
  : m_dataSeq (dataSeq),
    m_dataLen (dataLen),
    m_sfSeq (sfSeq),
    m_bAdvertized (false),
    m_seqTag (0)
{
  // One past the last subflow byte covered. Unsigned 32-bit addition wraps
  // exactly like SequenceNumber32 addition, so the result is the same.
  m_sfEnd = SequenceNumber32 (sfSeq + dataLen);
}

// Nothing to release explicitly: every member cleans up after itself.
NaMPTDSN::~NaMPTDSN (void)
{
}


/**************************************
***************************************
***************************************/

/*
 * Register NaMPTSubflow with the ns-3 type system.
 * It is a child of NaMPTL4Protocol and can be default-constructed through
 * the object factory.
 */
TypeId
NaMPTSubflow::GetTypeId (void)
{
  static TypeId tid = TypeId ("ns3::NaMPTSubflow")
    .SetParent<NaMPTL4Protocol> ()
    .AddConstructor<NaMPTSubflow> ();
  return tid;
}

/*
 * Default constructor.
 *
 * Most state, counters and pointers start at zero/false.
 * - m_baseRTT and m_minRTT start at 60 s and are lowered through Min() as RTT
 *   samples arrive. m_sampleRTT also starts at 60 s.
 * - m_gamma starts at 2 and m_weight at 1.0.
 * - m_sndPackets starts at 1, presumably so the drop ratio printed by
 *   OutputInfo never divides by zero (TODO confirm intent).
 *
 * The body records the creation time and creates this subflow's own
 * RttMeanDeviation estimator with gain 0.125.
 */
NaMPTSubflow::NaMPTSubflow ()
	: m_baseSocket (0),
	m_namptSocket (0),
	m_tcp (0),
	m_bAvailable (false),
	m_bBypassSending (false), 
	m_nextRescheduling (0),
	m_sampleRTT (Seconds(60.0) ),
	m_baseRTT (Seconds(60.0) ),
	m_begSeq (0),
	m_minRTT (Seconds(60.0) ),
	m_cnRTT (0),
	m_sumRTT (0),
	m_rounds (0), 
	m_cnAcks (0), 
	m_refWin (0),
	m_gamma (2),
	m_weight (1.0), 
	m_equilibrium (0), 
	m_instantRate (0), 
	m_minQueueDelay (0), 
	m_nCE (0), 
	m_nECE (0), 
	m_marking (0), 
	m_cwr (0), 
	m_cwrHighSeq (0), 
	m_incCum (0.0), 
	m_cnIncreased (0),
	m_cnDecreased (0),
	m_cnUnchanged (0), 
	m_rcvPkts (0),
	m_rcvBytes (0),
	m_goodBytes (0),
	m_signalBytes (0),
	m_cnTimeout (0),
	m_cnFastRetrans (0),
	m_cnRescheduling (0),
	m_byteRescheduling (0),
	m_sndPackets (1) 
{
  NS_LOG_FUNCTION_NOARGS ();
  NS_LOG_LOGIC("Made a NaMPTSubflow "<<this);

  // record creation time (the time origin for the rates printed by OutputInfo)
  m_startTime = Simulator::Now();

  // setup RTT estimator: this subflow's own estimator, separate from
  // m_baseSocket's m_rtt; it is fed by SendPacket (SentSeq) and ForwardUp (AckSeq)
  ObjectFactory rttFactory;
  rttFactory.SetTypeId (RttMeanDeviation::GetTypeId () );
  m_rtt = rttFactory.Create<RttMeanDeviation> ();
  m_rtt->Reset();
  m_rtt->Gain(0.125);
  
}

/*
 * Destructor.
 * Drops the references to the collaborating objects, then empties the
 * pending-signal queue and both DSN tables. Clearing a container destroys
 * the Ptr elements it holds, which releases each referenced object.
 */
NaMPTSubflow::~NaMPTSubflow ()
{
  NS_LOG_FUNCTION_NOARGS ();

  // collaborators
  m_baseSocket = 0;
  m_namptSocket = 0;
  m_tcp = 0;
  m_rtt = 0;

  // owned bookkeeping
  m_signalQueue.clear ();
  m_localDSN.clear ();
  m_peerDSN.clear ();
}

/*
 * Receive path for a segment arriving on this subflow. It filters out NaMPT
 * signals before the packet is delivered to m_baseSocket. Steps, in order:
 *   1. Strip the NaMPTHeader from the packet.
 *   2. Check the handshake.
 *      - In SYN_SENT, the reply must be a SYN+ACK carrying a JOIN signal whose
 *        peerToken equals our local token. This is checked with NS_ASSERT only.
 *      - In SYN_RCVD, an ACK marks the subflow available. If sending was being
 *        bypassed (the first subflow), bypassing stops and
 *        NaMPTSocket::AdvertizeAddresses is scheduled.
 *   3. Process every piggybacked signal:
 *      - ADDR schedules NaMPTSocket::TryPeerAddress.
 *      - DSN is recorded through ReceiveDSN.
 *      - Anything else is asserted to be a duplicated MPC/JOIN.
 *   4. Update the receive statistics and feed ACK numbers to m_rtt.
 *   5. Handle ECN:
 *      - A CE mark on a segment with payload increments m_nCE.
 *      - An ECE flag is cleared (counted in m_nECE) and may enter the CWR state.
 *   6. Remove all signals, restore the header, and pass the packet on to
 *      m_baseSocket->ForwardUp.
 */
void 
NaMPTSubflow::ForwardUp (Ptr<Packet> packet, Ipv4Header header, 
					uint16_t port, Ptr<Ipv4Interface> incomingInterface)
{
  NaMPTHeader namptHeader;
  packet->RemoveHeader(namptHeader);

  // syn+ack packet should piggyback join signal
  // (both checks are NS_ASSERTs and therefore only active in debug builds)
  if (m_baseSocket->m_state == SYN_SENT)
  {
	NS_ASSERT(namptHeader.GetFlags() == (TcpHeader::SYN | TcpHeader::ACK) );
	Ptr<NaMPTSignal> sig = namptHeader.FindSignal(NAMPT_JOIN, true);
	NS_ASSERT(sig != 0 && sig->sigData.join.peerToken == m_namptSocket->GetLocalToken() );
  }
  else if (m_baseSocket->m_state == SYN_RCVD && 
  	(namptHeader.GetFlags() & TcpHeader::ACK) )
  {
    // the passive side's handshake completes with this ACK
    NAMPT_LOG("Subflow is established: " << m_peerAddress << " ---> " << m_localAddress);
	m_bAvailable = true;
	if (m_bBypassSending) // the first subflow
    {
      m_bBypassSending = false;
	  // advertise the available addresses (deferred via ScheduleNow, not run inline)
	  Simulator::ScheduleNow(&NaMPTSocket::AdvertizeAddresses, m_namptSocket, Ptr<NaMPTSubflow>(this) );
    }
  }
  
  // deal with signals
  //if (NaMPTSocket::HaveSignals(packet) )
  std::vector<Ptr<NaMPTSignal> >::iterator it;
  Ptr<NaMPTSignal> sig = namptHeader.GetFirstSignal(it);
  while (sig != 0)
  {
    // signalling overhead: 2 bytes per signal (presumably type + length) plus its payload
    m_signalBytes += (2 + sig->sigLen);
	
    switch (sig->sigType)
    {
      case NAMPT_ADDR:
        Simulator::ScheduleNow(&NaMPTSocket::TryPeerAddress, m_namptSocket, 
      							sig->sigData.addr.ipv4Addr);
        break;
  	  case NAMPT_DSN:
  	    // must be recorded before m_baseSocket sees the payload it maps
  	    ReceiveDSN(sig->sigData.dsn.dataSeq, sig->sigData.dsn.dataLen, sig->sigData.dsn.sfSeq);
  	    break;
      default:
	  	// duplicated MPC or JOIN
	  	NS_ASSERT(sig->sigType == NAMPT_MPC ||sig->sigType == NAMPT_JOIN);
    }
  
    sig = namptHeader.GetNextSignal(it);
  }

  // statistics (byte count excludes the NaMPT header removed above)
  m_rcvPkts ++;
  m_rcvBytes += packet->GetSize();   

  // RTT estimation; a zero sample is treated as "no measurement"
  if (namptHeader.GetFlags () & TcpHeader::ACK)
  {
    m_sampleRTT = m_rtt->AckSeq (namptHeader.GetAckNumber () );
	if (m_sampleRTT != Seconds (0.0) )
    {
      m_baseRTT = Min(m_baseRTT, m_sampleRTT);
	  m_cnRTT ++;
	  m_sumRTT += m_sampleRTT.GetMicroSeconds();
	  m_minRTT = Min(m_minRTT, m_sampleRTT);
    }
    
  }

  // process ecn
  if (header.GetEcn() == Ipv4Header::CE)
  {
    // when getting a CE, the receiver records it and then returns it to the sender using ECN Echo.
    // (m_nCE is consumed by SendPacket, which sets ECE on outgoing segments)
    if (packet->GetSize() > 0) // not ack
      m_nCE ++;
	header.SetEcn(Ipv4Header::NotECT);
  }
  uint8_t flags = namptHeader.GetFlags();
  if (flags & TcpHeader::ECE)
  {
    m_nECE ++;
	// hide ECE from the base socket
	namptHeader.SetFlags(flags & (~TcpHeader::ECE));
	// enter CWR state
	// NOTE(review): m_cwr == 1 / 2 appear to mean "may enter CWR" / "in CWR";
	// confirm against the congestion-control code that resets m_cwr.
	if (m_cwr == 1)
	{
	  m_cwr = 2;
	  m_cwrHighSeq = m_baseSocket->m_nextTxSequence;
	  m_enterCWR (Ptr<NaMPTSubflow>(this), DynamicCast<NaMPTCC>(m_baseSocket));
	}
  }

  // forwarding: m_baseSocket only ever sees a plain TCP header
  namptHeader.RemoveAllSignals();
  packet->AddHeader(namptHeader);
  m_baseSocket->ForwardUp(packet, header, port, incomingInterface);

}

/*
 * Invoked when the endpoint is destroyed.
 * While the subflow is active, m_baseSocket->m_tcp points at this subflow so
 * that the socket's L4 calls are intercepted here. Point it back at the real
 * protocol (m_tcp) first, then let the base socket run its own Destroy().
 */
void
NaMPTSubflow::Destroy (void)
{
  NS_ASSERT (m_baseSocket->m_tcp == this);

  // undo the interception before the socket tears itself down
  m_baseSocket->m_tcp = m_tcp;
  m_baseSocket->Destroy ();
}

/*
 * Queue a new signal of type sigType for piggybacking. SendPacket attaches
 * at most one queued signal per outgoing segment, taken from the front of
 * the queue.
 *
 * The signal is returned so the caller can presumably fill in its sigData
 * before the next segment goes out.
 */
Ptr<NaMPTSignal>
NaMPTSubflow::SendSignal (TcpOptionType_t sigType)
{
  m_signalQueue.push_back (CreateObject<NaMPTSignal> (sigType));
  return m_signalQueue.back ();
}

/*
 * Invoked by TcpSocketBase through m_tcp->SendPacket(): the base socket's
 * m_tcp points at this subflow, which mimics TcpL4Protocol::SendPacket.
 *
 * When m_bBypassSending is set (used for syn+ack+mpc), the segment is handed
 * to the real protocol untouched. Otherwise the TCP header is converted into
 * a NaMPTHeader that piggybacks:
 *   - at most one signal taken from the front of m_signalQueue;
 *   - a JOIN signal carrying the peer token on SYN segments;
 *   - one DSN signal for each local mapping overlapping the payload that has
 *     not been advertised yet, or whose last advertisement ended after this
 *     segment's start (a retransmission);
 *   - the ECE flag, once per CE mark still pending in m_nCE.
 * Outgoing data is also reported to the subflow's RTT estimator.
 */
void
NaMPTSubflow::SendPacket (Ptr<Packet> packet, const TcpHeader &outgoing,
                          Ipv4Address saddr, Ipv4Address daddr, Ptr<NetDevice> oif)
{
  // for syn+ack+mpc
  if (m_bBypassSending)
    {
      m_tcp->SendPacket (packet, outgoing, saddr, daddr, oif);
      return;
    }

  // add nampt-related content.
  NaMPTHeader namptHeader (outgoing);

  // piggyback at most one queued signal per segment
  if (m_signalQueue.size () != 0)
    {
      namptHeader.AddSignal (m_signalQueue.front ());
      m_signalQueue.pop_front ();
    }

  // syn packet means joining a connection
  if (namptHeader.GetFlags () & TcpHeader::SYN)
    {
      Ptr<NaMPTSignal> sig = namptHeader.AddSignal (NAMPT_JOIN);
      sig->sigData.join.peerToken = m_namptSocket->GetPeerToken ();
    }

  // piggyback DSN
  uint32_t dataLen = packet->GetSize ();
  if (dataLen > 0)
    {
      // data: [sfSeq, sfEnd)
      SequenceNumber32 sfSeq = namptHeader.GetSequenceNumber ();
      SequenceNumber32 sfEnd = sfSeq + SequenceNumber32 (dataLen);

      // Find the mapping with the greatest key <= sfSeq.
      // upper_bound() returns begin() when no such mapping exists, including
      // when m_localDSN is empty. Decrementing begin() is undefined
      // behaviour, so check before stepping back. NS_ASSERT alone does not
      // protect optimized builds.
      DSNDB_t::iterator it = m_localDSN.upper_bound (sfSeq);
      NS_ASSERT (it != m_localDSN.begin ());
      if (it != m_localDSN.begin ())
        {
          --it;
          NS_ASSERT (it->second->m_sfSeq <= sfSeq && sfSeq < it->second->m_sfEnd);

          // visit every mapping overlapping [sfSeq, sfEnd)
          while (true)
            {
              if (!it->second->m_bAdvertized          // has not been sent
                  || sfSeq < it->second->m_seqTag)    // retransmission
                {
                  Ptr<NaMPTSignal> sig = namptHeader.AddSignal (NAMPT_DSN);
                  sig->sigData.dsn.dataSeq = static_cast<uint32_t> (it->second->m_dataSeq.GetValue ());
                  sig->sigData.dsn.dataLen = it->second->m_dataLen;
                  sig->sigData.dsn.sfSeq = static_cast<uint32_t> (it->second->m_sfSeq.GetValue ());

                  // tag: remember where this advertisement's segment ended
                  it->second->m_bAdvertized = true;
                  it->second->m_seqTag = sfEnd;
                }

              if (sfEnd <= it->second->m_sfEnd)
                break;

              // If the payload runs past the last mapping, the DSN table does
              // not cover the data being sent. Stop rather than dereference end().
              ++it;
              NS_ASSERT (it != m_localDSN.end ());
              if (it == m_localDSN.end ())
                break;
            }
        }
    }

  // notify the RTT (dataLen == packet->GetSize (); the packet is unchanged)
  m_rtt->SentSeq (namptHeader.GetSequenceNumber (), dataLen);

  // notify the sender of CE: echo one pending CE mark per segment
  if (m_nCE > 0)
    {
      uint8_t flags = namptHeader.GetFlags ();
      namptHeader.SetFlags (flags | TcpHeader::ECE);
      m_nCE--;
    }

  // forwarding...
  m_sndPackets += 1;
  m_tcp->SendPacket (packet, namptHeader, saddr, daddr, oif);
}

/*
 * Invoked by TcpSocketBase through m_tcp->DeAllocate (m_endPoint); mimics
 * TcpL4Protocol.
 * The base socket's m_tcp is first pointed back at the real protocol, and
 * the endpoint is then released there.
 */
void
NaMPTSubflow::DeAllocate (Ipv4EndPoint *endPoint)
{
  NS_ASSERT (m_baseSocket->m_tcp == this);

  // stop intercepting: later calls from the socket go straight to m_tcp
  m_baseSocket->m_tcp = m_tcp;
  m_tcp->DeAllocate (endPoint);
}

/*
 * Invoked by the scheduler. Queue packet p on this subflow; dataSeq is the
 * data-level sequence number of its first byte.
 *
 * Steps:
 *   1. Split the payload into DSN mappings of at most NaMPTDSN::MaxSize bytes
 *      (by repeatedly halving the chunk length).
 *   2. Register the mappings in m_localDSN, keyed by subflow sequence number
 *      starting at the tx buffer's tail. This happens *before* the bytes reach
 *      m_baseSocket, because the base socket may transmit immediately and
 *      SendPacket must then find the mappings.
 *   3. Before registering, drop mappings whose data and whose last carrying
 *      segment are both below the tx buffer head (acknowledged).
 *
 * Returns m_baseSocket->Send()'s result converted to uint32_t, as before. If
 * the base socket rejects the packet (negative result: buffer full or not
 * connected), the mappings just registered are removed again. m_localDSN
 * therefore keeps ending exactly at the tx buffer tail.
 */
uint32_t
NaMPTSubflow::Send (Ptr<Packet> p, const SequenceNumber32& dataSeq)
{
  uint32_t totalLen = p->GetSize ();
  NS_ASSERT (totalLen > 0);

  uint32_t perLen = totalLen;
  while (perLen > NaMPTDSN::MaxSize) // max dsn len
    perLen /= 2;

  SequenceNumber32 dSeq = dataSeq;
  SequenceNumber32 sSeq = m_baseSocket->m_txBuffer.TailSequence ();
  const SequenceNumber32 firstNewSeq = sSeq; // key of the first mapping added below

  if (m_localDSN.size () > 0)
    {
      // validate: existing mappings end exactly where the new data starts
      NS_ASSERT (m_localDSN.rbegin ()->second->m_sfEnd == sSeq);

      // remove expired dsn
      SequenceNumber32 highestAck = m_baseSocket->m_txBuffer.HeadSequence ();
      while (!m_localDSN.empty ())
        {
          DSNDB_t::iterator oldest = m_localDSN.begin ();
          // Both conditions are needed: the mapping's data has been received,
          // and so has the segment that last carried it (one segment may
          // carry two or more DSNs).
          if (!(oldest->second->m_sfEnd <= highestAck
                && oldest->second->m_seqTag <= highestAck))
            break;
          NS_ASSERT (oldest->second->m_bAdvertized == true);
          m_localDSN.erase (oldest);
        }
    }

  // add new dsn
  while (totalLen != 0)
    {
      uint32_t dLen = std::min (perLen, totalLen);
      m_localDSN[sSeq] = CreateObject<NaMPTDSN> (dSeq.GetValue (), dLen, sSeq.GetValue ());
      dSeq += dLen;
      sSeq += dLen;
      totalLen -= dLen;
    }

  // forwarding ...
  int sent = m_baseSocket->Send (p, 0);
  if (sent < 0)
    {
      // Nothing was queued, so the mappings just added describe bytes that do
      // not exist. Roll them back to keep the table consistent with the buffer.
      m_localDSN.erase (m_localDSN.lower_bound (firstNewSeq), m_localDSN.end ());
    }
  return sent;
}

/*
 * Report how many in-sequence bytes at the head of the receive buffer can be
 * delivered as one piece of the data-level stream.
 *
 * The count is limited to the peer DSN mapping that contains the first unread
 * subflow byte; consecutive mappings need not be contiguous at the data level.
 * On a non-zero return, dataSeq is set to the data-level sequence number of
 * that first byte. It is left untouched when 0 is returned.
 */
uint32_t
NaMPTSubflow::PeekData (SequenceNumber32& dataSeq)
{
  uint32_t dataLen = m_baseSocket->m_rxBuffer.Available ();
  if (dataLen == 0) // no in-sequence data
    return 0;

  // First in-sequence byte still held in the receive buffer. Once the buffer
  // is Finished(), NextRxSequence() presumably also counts the FIN, so step
  // back over it.
  SequenceNumber32 lowestInSequence = m_baseSocket->m_rxBuffer.NextRxSequence ();
  if (m_baseSocket->m_rxBuffer.Finished ())
    lowestInSequence--;
  lowestInSequence -= dataLen;

  // Find the mapping with the greatest key <= lowestInSequence.
  // upper_bound() yields begin() when there is none (also for an empty
  // table). Decrementing begin() is undefined behaviour, so in that case
  // report "nothing deliverable" instead of stepping back.
  DSNDB_t::iterator it = m_peerDSN.upper_bound (lowestInSequence);
  NS_ASSERT (it != m_peerDSN.begin ());
  if (it == m_peerDSN.begin ())
    return 0;
  --it;
  NS_ASSERT (lowestInSequence >= it->first &&
             lowestInSequence < it->second->m_sfEnd);

  uint32_t offSet = lowestInSequence - it->first;
  NS_ASSERT (offSet < NaMPTDSN::MaxSize && offSet < it->second->m_dataLen);

  dataSeq = it->second->m_dataSeq + offSet; // output dataSeq
  dataLen = std::min (dataLen, it->second->m_dataLen - offSet);
  NS_ASSERT (dataLen != 0);

  return dataLen;
}

/*
 * Read up to maxSize in-sequence bytes that all belong to a single peer DSN
 * mapping. PeekData decides how much is deliverable and writes the
 * data-level sequence number of the first byte into dataSeq.
 * Returns a null packet when nothing is deliverable or maxSize is 0.
 */
Ptr<Packet>
NaMPTSubflow::Recv (uint32_t maxSize, SequenceNumber32& dataSeq)
{
  uint32_t deliverable = PeekData (dataSeq); // output dataSeq

  if (maxSize == 0 || deliverable == 0)
    return 0;

  return m_baseSocket->Recv (std::min (maxSize, deliverable), 0);
}

/*
 * Retire this subflow.
 * It is marked unavailable, and its send callback is cleared because no more
 * data will be sent on it. The base socket's close callbacks are routed to
 * NaMPTSocket::DoSubflowNormalClose / DoSubflowErrorClose, so the owning
 * socket learns when the close completes. Then the base socket is closed.
 */
void
NaMPTSubflow::Close (void)
{
  m_bAvailable = false;

  // no more "ready to send" notifications
  m_baseSocket->SetSendCallback (MakeNullCallback<void, Ptr<Socket>, uint32_t> ());

  // report normal/error close back to the owning NaMPTSocket
  m_baseSocket->SetCloseCallbacks (
    MakeCallback (&NaMPTSocket::DoSubflowNormalClose, m_namptSocket),
    MakeCallback (&NaMPTSocket::DoSubflowErrorClose, m_namptSocket));

  m_baseSocket->Close ();
}

/*
 * Log per-subflow statistics through NAMPT_LOG.
 *
 * A subflow that has delivered goodput (m_goodBytes > 0) is treated as a
 * receiver. It prints throughput, goodput and signalling rate in Mbps, averaged
 * over the time since the subflow was created.
 *
 * Otherwise it is treated as a sender and prints:
 *   - timeout, fast-retransmit and rescheduling counters;
 *   - a drop ratio, (timeouts + fast retransmits) / sent packets;
 *   - base and smoothed RTTs in microseconds (own estimator / base socket's);
 *   - the increase/decrease/unchanged counters.
 *
 * Both branches are braced on purpose. NAMPT_LOG is a macro, and a brace-less
 * if/else around it (plus the commented-out guards) would let the `else`
 * silently rebind if the macro ever expanded to an `if`, or if a guard were
 * re-enabled.
 */
void
NaMPTSubflow::OutputInfo (void)
{
  // seconds since creation: the common denominator of the receiver-side rates
  double elapsed = Simulator::Now ().GetSeconds () - m_startTime.GetSeconds ();

  if (m_goodBytes > 0)
    {
      // Receiver show statistics
      //if (!m_baseSocket->m_shutdownRecv)
      NAMPT_LOG ("Subflow " << m_localAddress <<
                 //" RcvPkts:" << m_rcvPkts <<
                 //" RcvBytes:" << m_rcvBytes <<
                 //" Delivery:" << m_goodBytes <<
                 //" SignalBytes:" << m_signalBytes <<
                 " TP:" << m_rcvBytes*8/elapsed/1000000 << "Mbps" <<
                 " GP:" << m_goodBytes*8/elapsed/1000000 << "Mbps" <<
                 " SI:" << m_signalBytes*8/elapsed/1000000 << "Mbps");
    }
  else
    {
      // Sender show statistics
      //if (!m_baseSocket->m_shutdownSend)
      NAMPT_LOG ("Subflow " << m_localAddress <<
                 " TO:" << m_cnTimeout <<
                 " FR:" << m_cnFastRetrans <<
                 " RN:" << m_cnRescheduling <<
                 " RB:" << m_byteRescheduling <<
                 " Drop:" << static_cast<float> (m_cnTimeout + m_cnFastRetrans) / static_cast<float> (m_sndPackets) <<
                 " BaseRTT:" << m_baseRTT.GetMicroSeconds () << "us" <<
                 " SRTT:" << m_rtt->GetCurrentEstimate ().GetMicroSeconds () << "us/" << m_baseSocket->m_rtt->GetCurrentEstimate ().GetMicroSeconds () << "us" <<
                 " Inc:" << m_cnIncreased <<
                 " Dec:" << m_cnDecreased <<
                 " Unc:" << m_cnUnchanged);
    }
}

void 
NaMPTSubflow::ReceiveDSN(uint32_t dataSeq, uint32_t dataLen, uint32_t sfSeq)
{
  SequenceNumber32 sSeq(sfSeq);
  
  // validate
  if (m_peerDSN.size() > 0)
  {
    DSNDB_t::iterator it = m_peerDSN.lower_bound(sSeq);
	if (it == m_peerDSN.end() )
	{
	  it--;
	  NS_ASSERT(it->second->m_sfEnd <= sSeq);
	}
	else
	{
	  if (it->first == sSeq)
	  {
	    NS_ASSERT(it->second->m_dataSeq == SequenceNumber32(dataSeq) &&
			      it->second->m_dataLen == dataLen );
	  }
	  else
	  {
	    NS_ASSERT(sSeq + SequenceNumber32(dataLen) <= it->first);
	  }
	}
	
    // remove expired dsn
    SequenceNumber32 lowestInSequence = m_baseSocket->m_rxBuffer.NextRxSequence();
	if (m_baseSocket->m_rxBuffer.Finished() )
	  lowestInSequence--;
    lowestInSequence -= m_baseSocket->m_rxBuffer.Available();

	it = m_peerDSN.begin();
    while(it != m_peerDSN.end() )
    {
      if (it->second->m_sfEnd <= lowestInSequence)
      {
        DSNDB_t::iterator tmpIt = it++;
        tmpIt->second = 0;
  	    m_peerDSN.erase(tmpIt);
  	    continue;
      }
  	  break;
    }
  }

  // add new one
  m_peerDSN[sSeq] = CreateObject<NaMPTDSN>(dataSeq, dataLen, sfSeq);

}

/*
 * Bytes still free in the base socket's receive buffer.
 *
 * The subtraction is done only when it cannot wrap. The previous unsigned
 * `Max - Size` turned into a huge value whenever Size exceeded Max, and the
 * guarding NS_ASSERT is compiled out in optimized builds. Debug builds still
 * assert; release builds now report 0 free bytes instead.
 */
uint32_t
NaMPTSubflow::FreeSizeOfIngoingQueue (void)
{
  uint32_t capacity = m_baseSocket->m_rxBuffer.MaxBufferSize ();
  uint32_t used = m_baseSocket->m_rxBuffer.Size ();
  NS_ASSERT (used <= capacity);
  return (used < capacity) ? (capacity - used) : 0;
}

// Room left in the base socket's send buffer, as reported by GetTxAvailable().
uint32_t
NaMPTSubflow::FreeSizeOfOutgoingQueue (void)
{
  uint32_t room = m_baseSocket->GetTxAvailable ();
  return room;
}

uint32_t 
NaMPTSubflow::BacklogOfOutgoingQueue(void)
{
  return m_baseSocket->m_txBuffer.SizeFromSequence(m_baseSocket->m_nextTxSequence);
}

uint32_t 
NaMPTSubflow::FlightingOfOutgoingQueue(void)
{
  return m_baseSocket->m_highTxMark.Get() - m_baseSocket->m_txBuffer.HeadSequence();
}

// Usable part of the base socket's sending window, as computed by AvailableWindow().
uint32_t
NaMPTSubflow::FreeSizeOfSendingWindow (void)
{
  uint32_t window = m_baseSocket->AvailableWindow ();
  return window;
}

// Segment size configured on the base socket.
uint32_t
NaMPTSubflow::SegmentSize (void)
{
  uint32_t mss = m_baseSocket->m_segmentSize;
  return mss;
}

// Round size down to a whole number of base-socket segments.
// For unsigned values, size - size % mss equals (size / mss) * mss.
uint32_t
NaMPTSubflow::SegmentRound (uint32_t size)
{
  uint32_t mss = m_baseSocket->m_segmentSize;
  return size - (size % mss);
}

}; // namespace ns3

