/* -*- Mode:C++; c-file-style:"gnu"; indent-tabs-mode:nil; -*- */
/*
*
*	Author: Yu Cao <caoyu08@csnet1.cs.tsinghua.edu.cn>
 */


#include <ctype.h>

#include <algorithm>
#include <cassert>
#include <fstream>
#include <iomanip>
#include <iostream>
#include <map>
#include <string>
#include <vector>

#include "ns3/core-module.h"
#include "ns3/ipv4-global-routing-helper.h"
#include "ns3/nampt-l4-protocol.h"

#include "ns3/applications-module.h"
#include "ns3/network-module.h"
#include "ns3/internet-module.h"
#include "ns3/point-to-point-module.h"

/*
 Note: before building this script, comment out the macro
 #define NAMPT_LARGE_SCALE
 in "nampt-l4-protocol.h".
*/


// Parameters
#define APPNAME				"RateCompensation"
static const uint32_t		sndBufSize	= 400000;	// ns3::TcpSocket::SndBufSize (also set on every sender socket)
static const uint32_t		rcvBufSize	= 800000;	// ns3::TcpSocket::RcvBufSize
static const uint32_t		queueSize	= 100;		// ns3::DropTailQueue::MaxPackets
static const uint32_t		markLine		= 20;		// ns3::DropTailQueue::MarkLine
static const uint32_t		mygamma		= 1;		// ns3::NaMPTSocket::gamma
static const uint32_t		mybeta		= 4;		// ns3::NaMPTSocket::beta
static const double			minRTO		= 0.2;		// ns3::RttEstimator::MinRTO, in seconds
static const uint32_t		pktSize		= 1400;		// ns3::TcpSocket::SegmentSize
ns3::NaMPTAlgo_t			ccAlgo			= ns3::NAMPT_CC_XMP;	// ns3::NaMPTSocket::CongestionAlgo
ns3::NaMPTAlgo_t			schedulingAlgo	= ns3::NAMPT_SOD;		// ns3::NaMPTSocket::SchedulingAlgo
ns3::NaMPTAlgo_t			assemblingAlgo	= ns3::NAMPT_NA;		// ns3::NaMPTSocket::AssemblingAlgo
static const bool			bUpSingle = false;	// run the sender s0 -> d0
static const bool			bDownSingle = true;	// run the sender s<links> -> d<links>
// Rate of each bottleneck link r_i -> v_i; one entry per bottleneck.
// All other links are 10Gbps.
static const char *			bandWidth[] = 
	{"0.8Gbps", "1.2Gbps", "2Gbps", "1.5Gbps", "0.5Gbps"};
// Window estimates with 1500-byte packets (350us is presumably the RTT):
// 500M * 350us / (8 * 1500) = 15 pkts
// 2000M * 350us / (8 * 1500) = 58 ~ 60 pkts


static double				simTime		= 65; // seconds; overridable with --st
// Start/stop time (seconds) of sender s_i, indexed by i = 0 .. number of bottlenecks.
static const double			beginTime[] = 
	{0, 0,  5,  10, 15, 20};
static const double			endTime[] = 
	{0, 65, 65, 65, 65, 65};


// When bLoss is true, lossRate[i] is the packet loss rate applied at v_i
// on bottleneck i.
static const bool			bLoss = false;
// One entry per bottleneck. An empty initializer would declare a zero-length
// array, which ISO C++ forbids and which bottleneck i would index out of bounds.
static const double			lossRate[] = 
	{0, 0, 0, 0, 0};




using namespace ns3;

// Defines this program's log component, named APPNAME ("RateCompensation").
// The NS_LOG_* macros used in this file log under this component.
NS_LOG_COMPONENT_DEFINE (APPNAME);

/*
 * Bulk-transfer sender.
 *
 * Owns one TCP socket on a named node. The socket connects to a target
 * after a start delay, keeps its send buffer full, and is closed once the
 * stop time is reached.
 */
class BulkApplication : public Object {
public:
  // Creates, binds and receive-shuts a TCP socket on the node registered
  // under nodeName.
  BulkApplication (std::string nodeName);
  virtual ~BulkApplication ();

  // Applies an attribute to the underlying socket, not to this object.
  void SetAttribute (std::string name, const AttributeValue &value);

  // Schedules the connection to target:port after `start`; `stop` bounds
  // the transfer.
  void Run (Ipv4Address target, uint16_t port, Time const& start, Time const& stop);

protected:
  // Rewinds the payload cursor.
  void FillBuffer ();

  // Socket callbacks and scheduled actions.
  void WriteUntilBufferFull (Ptr<Socket> localSocket, uint32_t txAvailable);
  void DataHasSent (Ptr<Socket> localSocket, uint32_t txSize);
  void ConnectionSucceeds (Ptr<Socket> localSocket);
  void ConnectionFails (Ptr<Socket> localSocket);
  void NormalClose (Ptr<Socket> localSocket);
  void ErrorClose (Ptr<Socket> localSocket);
  void StartTransmission (Ptr<Socket> localSocket, Ipv4Address servAddress, uint16_t servPort);
  void StopTransmission (Ptr<Socket> localSocket);

protected:
  static const uint32_t bufSize = 1000000;  // capacity of appBuf, in bytes
  uint8_t      appBuf[bufSize];  // payload source (its fill/copy code is currently disabled)
  uint32_t     offSet;           // read cursor into appBuf
  Time         startTime;        // delay passed to Run() before connecting
  Time         stopTime;         // stop time passed to Run()
  Ptr<Node>    node;             // node hosting the socket
  Ptr<Socket>  localSocket;      // the sending TCP socket
};

BulkApplication::BulkApplication (std::string nodeName)
  : node (Names::Find<Node> (nodeName)),
    localSocket (Socket::CreateSocket (node, TcpSocketFactory::GetTypeId ()))
{
  // Start with the payload cursor at the beginning of appBuf.
  FillBuffer ();

  // Binding is optional for a sender (Connect would bind implicitly);
  // it is done here up front.
  localSocket->Bind ();

  // This socket only transmits, so refuse inbound data.
  localSocket->ShutdownRecv ();
}

BulkApplication::~BulkApplication (void)
{
  // Release our references: the socket first, then the node.
  localSocket = Ptr<Socket> ();
  node = Ptr<Node> ();
}

void 
BulkApplication::SetAttribute (std::string name, const AttributeValue &value)
{
  localSocket->SetAttribute(name, value);
}

void
BulkApplication::Run (Ipv4Address target, uint16_t port, Time const& start, Time const& stop)
{
  // Remember the transfer window; stopTime is only consulted once the
  // connection has been established (ConnectionSucceeds).
  stopTime = stop;
  startTime = start;

  // `start` is a delay from the current simulation time.
  Simulator::Schedule (startTime,
                       &BulkApplication::StartTransmission, this,
                       localSocket, target, port);
}

void
BulkApplication::FillBuffer (void)
{
  // Generating payload contents (a repeating 'a'..'z' pattern in appBuf)
  // is currently disabled; only the read cursor is rewound to the start.
  offSet = 0;
}

void 
BulkApplication::WriteUntilBufferFull (Ptr<Socket> localSocket, uint32_t txAvailable)
{
	// Note: txAvailable is inaccurate
	txAvailable = localSocket->GetTxAvailable ();
	if (txAvailable <= 0)
		return ;

	NS_LOG_LOGIC ("Application sends " << txAvailable << " bytes at " << Simulator::Now().GetSeconds() );

	// Fill out tx buffer  
	uint8_t *buff = new uint8_t [txAvailable];
	uint32_t left = txAvailable;
	uint32_t i = 0;
	if (buff == 0) return ;
	while (left > 0)
	{
		uint32_t toWrite = bufSize - offSet;
		toWrite = std::min (toWrite, left);
		//memcpy (&buff[i], &appBuf[offSet], toWrite);

		offSet += toWrite;
		left -= toWrite;
		i += toWrite;
        
		if (offSet >= bufSize)
			FillBuffer ();
	}
	localSocket->Send (buff, txAvailable, 0);
	delete []buff;
}

void 
BulkApplication::StopTransmission (Ptr<Socket> localSocket)
{
	NS_LOG_LOGIC ("StopTransmission at " << Simulator::Now().GetSeconds() );
	localSocket->Close ();
}

void 
BulkApplication::ConnectionSucceeds (Ptr<Socket> localSocket)
{
	NS_LOG_LOGIC ("Connection succeeds at " << Simulator::Now().GetSeconds() );
	
	Time cur = Simulator::Now();
	if (stopTime < cur)
	  stopTime = cur;
	
	Simulator::Schedule(stopTime-startTime, &BulkApplication::StopTransmission, this, localSocket);
	WriteUntilBufferFull (localSocket, localSocket->GetTxAvailable() );
}

void 
BulkApplication::ConnectionFails (Ptr<Socket> localSocket)
{
	NS_LOG_LOGIC ("Connection fails at " << Simulator::Now().GetSeconds() );
	localSocket->Close ();
}

void 
BulkApplication::DataHasSent (Ptr<Socket> localSocket, uint32_t txSize)
{
	NS_LOG_LOGIC (txSize << " bytes has been sent at " << Simulator::Now().GetSeconds() );
}

void 
BulkApplication::NormalClose (Ptr<Socket> localSocket)
{
	NS_LOG_LOGIC ("Peer has been gracefully closed at " << Simulator::Now().GetSeconds() );
	localSocket->Close ();
}

void 
BulkApplication::ErrorClose (Ptr<Socket> localSocket)
{
	NS_LOG_LOGIC ("Errors break transmission down at " << Simulator::Now().GetSeconds() );
	localSocket->Close ();
}

void 
BulkApplication::StartTransmission (Ptr<Socket> localSocket, Ipv4Address servAddress, uint16_t servPort)
{
	NS_LOG_LOGIC ("Starting transmission at time " <<  Simulator::Now ().GetSeconds () );

	localSocket->SetConnectCallback (MakeCallback (&BulkApplication::ConnectionSucceeds,this), MakeCallback (&BulkApplication::ConnectionFails,this) ); 
	localSocket->SetSendCallback (MakeCallback (&BulkApplication::WriteUntilBufferFull, this) );
	localSocket->SetDataSentCallback (MakeCallback (&BulkApplication::DataHasSent, this) );
	localSocket->SetCloseCallbacks (MakeCallback (&BulkApplication::NormalClose,this), MakeCallback(&BulkApplication::ErrorClose,this) );

	if (localSocket->Connect (InetSocketAddress (servAddress, servPort) ) != 0)
		NS_LOG_ERROR ("Connect fails!");
  
}

/*
 * Installs a new receive error model on `ppp` that drops packets at rate
 * `loss`, replacing any model already installed there.
 */
void DoSetLossRate(Ptr<PointToPointNetDevice> ppp, float loss)
{
	// Each model gets its own uniform(0, 1) random variable.
	Ptr<RateErrorModel> model =
		CreateObjectWithAttributes<RateErrorModel>("RanVar", RandomVariableValue (UniformVariable (0.0, 1.0)));

	// Apply the rate per packet (the old "EU_PKT" name is not accepted here).
	model->SetAttribute ("ErrorRate", DoubleValue (loss));
	model->SetAttribute ("ErrorUnit", StringValue ("ERROR_UNIT_PACKET"));

	// The model filters packets received by this device.
	ppp->SetAttribute ("ReceiveErrorModel", PointerValue (model) );
}

void SetLinkLossRate(const NetDeviceContainer& dev, float forwardLoss, float backwardLoss, Time start, Time stop)
{
	Ptr<PointToPointNetDevice> ppp0 = dev.Get(0)->GetObject<PointToPointNetDevice>();
	Ptr<PointToPointNetDevice> ppp1 = dev.Get(1)->GetObject<PointToPointNetDevice>();

	// save old loss rates
	PointerValue pv;
	ppp1->GetAttribute ("ReceiveErrorModel", pv);
	Ptr<RateErrorModel> fem = pv.Get<RateErrorModel>();
	ppp0->GetAttribute ("ReceiveErrorModel", pv);
	Ptr<RateErrorModel> bem = pv.Get<RateErrorModel>();

	// new loss rates
	Simulator::Schedule(start, &DoSetLossRate, ppp1, forwardLoss);
	Simulator::Schedule(start, &DoSetLossRate, ppp0, backwardLoss);

	// recover loss rates
	if (fem != 0)
		Simulator::Schedule(stop, &DoSetLossRate, ppp1, fem->GetRate() );
	if (bem != 0)
		Simulator::Schedule(stop, &DoSetLossRate, ppp0, bem->GetRate() );
	
}


void SetBackgroundFlows(std::string ingress, std::string outgress, uint32_t flows, Time start, Time stop )
{
	static uint32_t addr = 0;
	NS_ASSERT(addr <= 255);

	NodeContainer p, q;
	NetDeviceContainer dev;
	PointToPointHelper p2p;
	InternetStackHelper stack;
	Ipv4AddressHelper address;

	p.Create(flows);
	q.Create(flows);
	stack.SetTcp("ns3::NaMPTL4Protocol");
	stack.Install(p);
	stack.Install(q);
	
	p2p.SetDeviceAttribute ("DataRate", StringValue("10Gbps") );
	p2p.SetChannelAttribute("Delay", StringValue("50us") );

	PacketSinkHelper sink ("ns3::TcpSocketFactory", 
							InetSocketAddress (Ipv4Address::GetAny (), 8000) );

	BulkApplication* bulkApp = 0;
	char buf[1024];
	for (uint32_t i = 0; i < flows; i++, addr++)
    {
      snprintf(buf, 1024, "p%u", addr);
	  Names::Add (buf, p.Get(i) );
      dev = p2p.Install(p.Get(i), ingress);
	  snprintf(buf, 1024, "166.111.%u.0", addr);
	  address.SetBase(buf, "255.255.255.0");
	  address.Assign(dev);

	  snprintf(buf, 1024, "q%u", addr);
	  Names::Add (buf, q.Get(i) );
	  dev = p2p.Install(outgress, q.Get(i) );
	  snprintf(buf, 1024, "59.66.%u.0", addr);
	  address.SetBase(buf, "255.255.255.0");
	  address.Assign(dev);

	  ApplicationContainer apps = sink.Install (q.Get(i) );
	  apps.Start (start);
	  //apps.Stop (stop);
	  
      bulkApp = new BulkApplication(Names::FindName(p.Get(i)) );
	  snprintf(buf, 1024, "59.66.%u.2", addr);
      bulkApp->Run(Ipv4Address(buf), 8000, start, stop);
    }
	
	return ;
}



int main (int argc, char *argv[])
{
	/* Enable log components */
	//LogComponentEnable ("PacketSink", LOG_LEVEL_ALL);
	//LogComponentEnableAll (LOG_LEVEL_ALL);

	/* Parse parameters */
	CommandLine cmd;
	cmd.AddValue ("st", "Simulation Time", simTime);
	cmd.Parse (argc, argv);

	/* Global configurations */
	Config::SetDefault ("ns3::TcpL4Protocol::SocketType", StringValue ("ns3::TcpReno") );
	Config::SetDefault ("ns3::TcpSocket::SegmentSize", UintegerValue (pktSize) );
	Config::SetDefault ("ns3::TcpSocket::DelAckCount", UintegerValue(1) );
	Config::SetDefault ("ns3::TcpSocket::SlowStartThreshold", UintegerValue(100) );
	Config::SetDefault ("ns3::TcpSocket::InitialCwnd", UintegerValue(2) );
	Config::SetDefault ("ns3::TcpSocket::RcvBufSize", UintegerValue(rcvBufSize) );
	Config::SetDefault ("ns3::TcpSocket::SndBufSize", UintegerValue(sndBufSize) );
	Config::SetDefault ("ns3::NaMPTSocket::CongestionAlgo", UintegerValue(ccAlgo) );
    Config::SetDefault ("ns3::NaMPTSocket::SchedulingAlgo", UintegerValue(schedulingAlgo) );
	Config::SetDefault ("ns3::NaMPTSocket::AssemblingAlgo", UintegerValue(assemblingAlgo) );
	Config::SetDefault ("ns3::NaMPTSocket::gamma", UintegerValue(mygamma) );
	Config::SetDefault ("ns3::NaMPTSocket::beta", UintegerValue(mybeta) );
	Config::SetDefault ("ns3::DropTailQueue::MaxPackets", UintegerValue(queueSize) );
	Config::SetDefault ("ns3::DropTailQueue::MarkLine", UintegerValue(markLine) );
	Config::SetDefault ("ns3::RttEstimator::MinRTO", TimeValue(Seconds(minRTO)) );
	
	/* Construct topology */
	uint32_t links = sizeof(bandWidth)/sizeof(char *);
	NodeContainer s, r, v, d;
	NetDeviceContainer dev;
	PointToPointHelper p2p;
	InternetStackHelper stack;
	Ipv4AddressHelper address;
	uint32_t route = 0;
	static const uint32_t buflen = 1024;
	char buf[buflen];


	// nodes and stacks
	r.Create(links);
	v.Create(links);
	s.Create(links + 1);
	d.Create(links + 1);
	
	stack.SetTcp ("ns3::NaMPTL4Protocol");
	stack.Install(s);
	stack.Install(d);
	stack.Install(r);
	stack.Install(v);

	// bottlenecks
	std::map<uint32_t, NetDeviceContainer> bottlenecks;
	for (uint32_t i = 0; i < links; i ++)
    {
      snprintf(buf, buflen, "r%u", i);
      Names::Add (buf, r.Get(i) );
	  snprintf(buf, buflen, "v%u", i);
	  Names::Add (buf, v.Get(i) );

	  p2p.SetDeviceAttribute ("DataRate", StringValue(bandWidth[i]) );
	  p2p.SetChannelAttribute("Delay", StringValue("50us") );
	  dev = p2p.Install (r.Get(i), v.Get(i) );
	  snprintf(buf, buflen, "192.168.%u.0", i);
	  address.SetBase (buf, "255.255.255.0");
	  address.Assign(dev);
	  bottlenecks[i] = dev;

	  // drop on v_i
	  if (bLoss)
	  {
	    Ptr<RateErrorModel> em = CreateObjectWithAttributes<RateErrorModel> (
				"ErrorUnit", StringValue ("EU_PKT"),
				"RanVar",    RandomVariableValue (UniformVariable (0.0, 1.0)),
				"ErrorRate", DoubleValue (lossRate[i]) );
	    dev.Get(1)->SetAttribute ("ReceiveErrorModel", PointerValue (em) );
	  }
    }

	// Other links
	//if (bUpSingle)
    {
      Names::Add ("s0", s.Get(0) );
	  Names::Add ("d0", d.Get(0) );
	  p2p.SetDeviceAttribute ("DataRate", StringValue("10Gbps") );
	  p2p.SetChannelAttribute("Delay", StringValue("50us") );
	  
	  dev = p2p.Install ("s0", "r0");
	  address.SetBase ("10.0.0.0", "255.255.255.0");
	  address.Assign(dev);
	  dev = p2p.Install ("v0", "d0");
	  address.SetBase ("10.0.1.0", "255.255.255.0");
	  address.Assign(dev);
	  route++;
    }
	
	for (uint32_t i = 1; i < links; i ++) //  indexed by s or d
    {
    	snprintf(buf, buflen, "s%u", i);
		Names::Add (buf, s.Get(i) );
		snprintf(buf, buflen, "d%u", i);
		Names::Add (buf, d.Get(i) );
		p2p.SetDeviceAttribute ("DataRate", StringValue("10Gbps") );
		p2p.SetChannelAttribute("Delay", StringValue("50us") );
		
		dev = p2p.Install (s.Get(i), r.Get(i-1) );
		snprintf(buf, buflen, "10.%u.0.0", route);
		address.SetBase (buf, "255.255.255.0");
		address.Assign(dev);
		dev = p2p.Install (v.Get(i-1), d.Get(i) );
		snprintf(buf, buflen, "10.%u.1.0", route);
		address.SetBase (buf, "255.255.255.0");
		address.Assign(dev);
		route ++;

        dev = p2p.Install (s.Get(i), r.Get(i) );
        snprintf(buf, buflen, "10.%u.0.0", route);
        address.SetBase (buf, "255.255.255.0");
        address.Assign(dev);
        dev = p2p.Install (v.Get(i), d.Get(i) );
        snprintf(buf, buflen, "10.%u.1.0", route);
        address.SetBase (buf, "255.255.255.0");
        address.Assign(dev);
        route ++;
    }

	//if (bDownSingle)
    {
      snprintf(buf, buflen, "s%u", links);
      Names::Add (buf, s.Get(links) );
	  snprintf(buf, buflen, "d%u", links);
	  Names::Add (buf, d.Get(links) );
	  p2p.SetDeviceAttribute ("DataRate", StringValue("10Gbps") );
	  p2p.SetChannelAttribute("Delay", StringValue("50us") );
	  
	  dev = p2p.Install (s.Get(links), r.Get(links-1) );
	  snprintf(buf, buflen, "10.%u.0.0", route);
	  address.SetBase (buf, "255.255.255.0");
	  address.Assign(dev);
	  dev = p2p.Install (v.Get(links-1), d.Get(links) );
	  snprintf(buf, buflen, "10.%u.1.0", route);
	  address.SetBase (buf, "255.255.255.0");
	  address.Assign(dev);
	  route++;
    }

	// Torus
	dev = p2p.Install (s.Get(links), r.Get(0) );
	snprintf(buf, buflen, "10.%u.0.0", route);
	address.SetBase (buf, "255.255.255.0");
	address.Assign(dev);
	dev = p2p.Install (v.Get(0), d.Get(links) );
	snprintf(buf, buflen, "10.%u.1.0", route);
	address.SetBase (buf, "255.255.255.0");
	address.Assign(dev);
	route++;
	
	/* Background Flows */
	SetBackgroundFlows("r2", "v2", 1, Seconds(25), Seconds(45) );
	SetBackgroundFlows("r2", "v2", 1, Seconds(30), Seconds(45) );
	SetBackgroundFlows("r2", "v2", 1, Seconds(35), Seconds(50) );
	SetBackgroundFlows("r2", "v2", 1, Seconds(40), Seconds(55) );

	// link down
	SetLinkLossRate(bottlenecks[2], 1, 0, Seconds(60), Seconds(65) );

	
	/* Setup global routing tables */
	Ipv4GlobalRoutingHelper::PopulateRoutingTables ();


	/* Destination */
	for (uint32_t i = 0; i < links+1; i++) // d
    {
		PacketSinkHelper sink ("ns3::TcpSocketFactory", 
								InetSocketAddress (Ipv4Address::GetAny (), 8000) );
		ApplicationContainer apps = sink.Install (d.Get(i) );
		apps.Start (Seconds(0.0) );
		//apps.Stop (Seconds(simTime) );
    }
	
	/* Sender */
	BulkApplication* bulkApp = 0;
    if (bUpSingle)
    {
      bulkApp = new BulkApplication("s0");
	  bulkApp->SetAttribute("SndBufSize", UintegerValue(sndBufSize) );
	  bulkApp->Run(Ipv4Address("10.0.1.2"), 8000, Seconds(beginTime[0]), Seconds(endTime[0]) );
    }
    for (uint32_t i = 1; i < links; i ++) //  indexed by s or d
    {
      // get one of ip addresses of d_i
      Ptr<Ipv4L3Protocol> ipv4 = d.Get(i)->GetObject<Ipv4L3Protocol>();
	  Ptr<Ipv4Interface> itf = ipv4->GetInterface(1); // 0 is loop interface

	  // app
      snprintf(buf, buflen, "s%u", i);
	  bulkApp = new BulkApplication(buf);
	  bulkApp->SetAttribute("SndBufSize", UintegerValue(sndBufSize) );
	  bulkApp->Run(itf->GetAddress(0).GetLocal(), 8000, Seconds(beginTime[i]), Seconds(endTime[i]) );
    }
    if (bDownSingle)
    {
      // get one of ip addresses of d_i
      Ptr<Ipv4L3Protocol> ipv4 = d.Get(links)->GetObject<Ipv4L3Protocol>();
      Ptr<Ipv4Interface> itf = ipv4->GetInterface(1); // 0 is loop interface
      
      // app
      snprintf(buf, buflen, "s%u", links);
      bulkApp = new BulkApplication(buf);
      bulkApp->SetAttribute("SndBufSize", UintegerValue(sndBufSize) );
      bulkApp->Run(itf->GetAddress(0).GetLocal(), 8000, Seconds(beginTime[links]), Seconds(endTime[links]) );
    }
	
	/* Simulation begins */
	Simulator::Stop (Seconds(simTime) );
	Simulator::Run ();
	Simulator::Destroy (); 
	return 0;

}

