/*
 * MPTCP MDT congestion control
 *
 * Yu Cao: cyAnalyst@126.com
 */

#include <linux/mm.h>
#include <linux/skbuff.h>
#include <linux/inet_diag.h>
#include <net/tcp.h>
#include <net/mptcp.h>
#include <linux/module.h>
#include <linux/tcp.h>
#include <linux/time.h>
#include <linux/timex.h>
#include <linux/rtc.h>

#include "net/mptcp_xmp.h"

/*
 * Tunables (writable at runtime via /sys/module/.../parameters, mode 0644).
 *
 * NOTE(review): the shift parameters are used directly as shift counts;
 * values that are negative or >= the operand width are undefined behaviour
 * in C, so callers must keep them small and non-negative.
 */

/* EWMA gain for srtt_us in mptcp_xmp_pkts_acked(): new sample weighs 1/2^#. */
static int mptcp_xmp_shift_usrtt __read_mostly	= 3;
/* EWMA gain for alpha in mptcp_xmp_cong_avoid(): new sample weighs 1/2^#. */
static int mptcp_xmp_shift_g __read_mostly		= 4;
/* Per-round additive increase, multiplied by the subflow weight. */
static int mptcp_xmp_ca_adder __read_mostly		= 1;
/* cwnd is cut by cwnd/# on ECE, only when sysctl_mptcp_xmp_mode == 2 and # > 0. */
static int mptcp_xmp_reducer __read_mostly		= 2;

module_param(mptcp_xmp_shift_usrtt, int, 0644);
MODULE_PARM_DESC(mptcp_xmp_shift_usrtt, "smooth rtt, 2^#");
module_param(mptcp_xmp_shift_g, int, 0644);
MODULE_PARM_DESC(mptcp_xmp_shift_g, "smooth g, 2^#");
module_param(mptcp_xmp_ca_adder, int, 0644);
MODULE_PARM_DESC(mptcp_xmp_ca_adder, "increase cwnd by # packets per RTT");
module_param(mptcp_xmp_reducer, int, 0644);
MODULE_PARM_DESC(mptcp_xmp_reducer, "decrease cwnd by a factor of 1/# if congested");


/*
 * mptcp_xmp_scale - widen @val to 64 bits and shift it left by @scale.
 *
 * Used to convert integers into the module's fixed-point representation
 * (scale = MPTCP_XMP_SCALE fractional bits).
 */
static inline u64 mptcp_xmp_scale(u32 val, int scale)
{
	u64 wide = val;

	return wide << scale;
}


/*
 * mptcp_xmp_rate - fixed-point sending rate estimate.
 * @cwnd:   congestion window in packets
 * @rtt_us: round-trip time in microseconds (callers pass a positive value)
 *
 * Returns (cwnd << MPTCP_XMP_SCALE) / rtt_us, i.e. packets per microsecond
 * with MPTCP_XMP_SCALE fractional bits.
 */
static inline u64 mptcp_xmp_rate(u32 cwnd, s32 rtt_us)
{
	u64 scaled_cwnd = mptcp_xmp_scale(cwnd, MPTCP_XMP_SCALE);

	return div64_u64(scaled_cwnd, rtt_us);
}


/*
 * mptcp_xmp_ssthresh - new slow-start threshold after an ECN signal.
 *
 * Returns the smaller of the current ssthresh and cwnd - 1, which pushes a
 * flow that is still in slow start into congestion avoidance.
 */
static inline u32 mptcp_xmp_ssthresh(struct tcp_sock *tp)
{
	u32 below_cwnd = tp->snd_cwnd - 1;

	return (tp->snd_ssthresh < below_cwnd) ? tp->snd_ssthresh : below_cwnd;
}


/*
 * mptcp_xmp_receive_ece - react to ECE/CWR bits on an incoming segment.
 * @sk:  subflow socket
 * @skb: received segment
 *
 * SYN segments, and segments with neither ECE nor CWR set, are ignored.
 * Otherwise the peer's CE count is decoded (ECE = 1, CWR = 2, both = 3),
 * accumulated into ca->cnt_ece, and both bits are cleared in the header.
 * NOTE(review): clearing them is presumably meant to keep the stock TCP ECN
 * logic from reacting as well. That is not confirmed from this file.
 *
 * The first signal while in MPTCP_XMP_STATE_NORMAL moves the subflow to
 * MPTCP_XMP_STATE_CWND_REDUCED. It remembers snd_nxt (cwr_seq) and the
 * pre-cut window (old_cwnd). When cwnd is above ssthresh it cuts cwnd by:
 *   mode 1:           (cwnd/2) * alpha
 *   mode 2:           cwnd / mptcp_xmp_reducer   (only if reducer > 0)
 *   otherwise:        cwnd/2
 * The cut is never less than 1 packet, and cwnd never drops below 2.
 */
static void mptcp_xmp_receive_ece(struct sock *sk, struct sk_buff *skb)
{
	struct tcphdr *th = tcp_hdr(skb);
	struct mptcp_xmp *ca;
	struct tcp_sock *tp;
	u32 cut;

	if (th->syn || !(th->ece || th->cwr))
		return;

	ca = inet_csk_ca(sk);
	ca->cnt_ece += (th->ece ? 1 : 0) + (th->cwr ? 2 : 0);
	th->ece = 0;
	th->cwr = 0;

	/* Cut the window at most once until the state returns to normal. */
	if (ca->state != MPTCP_XMP_STATE_NORMAL)
		return;

	tp = tcp_sk(sk);
	ca->state = MPTCP_XMP_STATE_CWND_REDUCED;
	ca->cwr_seq = tp->snd_nxt;
	ca->old_cwnd = tp->snd_cwnd;

	if (tp->snd_cwnd > tp->snd_ssthresh) {
		if (sysctl_mptcp_xmp_mode == 1)
			cut = ((tp->snd_cwnd >> 1) * ca->alpha) >> MPTCP_XMP_SCALE;
		else if (sysctl_mptcp_xmp_mode == 2 && mptcp_xmp_reducer > 0)
			cut = div64_u64(mptcp_xmp_scale(tp->snd_cwnd, MPTCP_XMP_SCALE),
					mptcp_xmp_reducer) >> MPTCP_XMP_SCALE;
		else
			cut = tp->snd_cwnd >> 1;

		if (!cut)
			cut = 1;

		tp->snd_cwnd = (tp->snd_cwnd > cut) ? tp->snd_cwnd - cut : 0;
		if (tp->snd_cwnd < 2)
			tp->snd_cwnd = 2;
	}

	tp->snd_ssthresh = mptcp_xmp_ssthresh(tp);
}


/*
 * mptcp_xmp_send_ece - encode pending CE marks into an outgoing header.
 * @sk:  subflow socket
 * @skb: segment being sent
 *
 * ca->cnt_ce is reported to the peer using the ECE/CWR bit pair:
 *   0   -> neither bit
 *   1   -> ECE only, count cleared
 *   2   -> CWR only, count cleared
 *   >=3 -> both bits, 3 marks consumed
 */
static void mptcp_xmp_send_ece(struct sock *sk, struct sk_buff *skb)
{
	struct tcphdr *th = tcp_hdr(skb);
	struct mptcp_xmp *ca = inet_csk_ca(sk);

	if (ca->cnt_ce == 0) {
		th->ece = 0;
		th->cwr = 0;
	} else if (ca->cnt_ce == 1) {
		th->ece = 1;
		th->cwr = 0;
		ca->cnt_ce = 0;
	} else if (ca->cnt_ce == 2) {
		th->ece = 0;
		th->cwr = 1;
		ca->cnt_ce = 0;
	} else {
		th->ece = 1;
		th->cwr = 1;
		ca->cnt_ce -= 3;
	}
}


/*
 * mptcp_xmp_weight - fixed-point share of the connection owned by @sk.
 * @mpcb:      MPTCP control block, may be NULL
 * @sk:        subflow whose weight is computed
 * @conn_rate: out: sum of instant_rate over eligible subflows (0 if none,
 *             or if @mpcb is NULL)
 *
 * Eligible subflows are those in TCP_ESTABLISHED with an RTT sample
 * (srtt_us > 1; srtt_us == 1 is the "no sample" sentinel from init).
 * Returns
 *     (instant_rate(sk) * srtt_us(sk) << MPTCP_XMP_SCALE)
 *         / (min_rtt * total_rate)
 * where min_rtt is the smallest srtt_us among eligible subflows. If there is
 * no usable rate information, the previous ca->weight is returned.
 *
 * The numerator is now computed entirely in u64. The earlier code passed the
 * u64 product through mptcp_xmp_scale(u32, ...), which truncated it to 32
 * bits before shifting.
 * NOTE(review): this assumes MPTCP_XMP_SCALE is small enough that
 * cwnd << (2 * MPTCP_XMP_SCALE) fits in 64 bits. Confirm against the header.
 */
static u64 mptcp_xmp_weight(struct mptcp_cb *mpcb, struct sock *sk, u64 *conn_rate)
{
	struct mptcp_xmp *ca = inet_csk_ca(sk);
	struct sock *sub_sk;
	u64 total_rate = 0;
	s32 min_rtt = 0x7fffffff;

	if (!mpcb) {
		*conn_rate = 0;
		return ca->weight;
	}

	mptcp_for_each_sk(mpcb, sub_sk) {
		struct mptcp_xmp *sub_ca = inet_csk_ca(sub_sk);

		if (sub_sk->sk_state == TCP_ESTABLISHED &&
		    sub_ca->srtt_us > 1) {
			total_rate += sub_ca->instant_rate;
			if (min_rtt > sub_ca->srtt_us)
				min_rtt = sub_ca->srtt_us;
		}
	}
	*conn_rate = total_rate;

	/* total_rate != 0 implies min_rtt was updated from a positive srtt. */
	if (!total_rate)
		return ca->weight;

	return div64_u64((ca->instant_rate * (u64)ca->srtt_us) << MPTCP_XMP_SCALE,
			 (u64)min_rtt * total_rate);
}


static void mptcp_xmp_reset(struct sock *sk)
{
	struct mptcp_xmp *ca = inet_csk_ca(sk);
	ca->beg_seq = tcp_sk(sk)->snd_nxt;
	ca->instant_rate = 0;
	ca->baseRTT_us = 0x7fffffff;
	ca->minRTT_us = 0x7fffffff;
	ca->state = MPTCP_XMP_STATE_NORMAL;
	ca->weight = mptcp_xmp_scale(1, MPTCP_XMP_SCALE);
	ca->alpha = 0;
	ca->cnt_ece = 0;
	ca->cwr_seq = 0;
	ca->old_cwnd = tcp_sk(sk)->snd_cwnd;
}


/*
 * mptcp_xmp_init - tcp_congestion_ops.init hook.
 * @sk: subflow socket
 *
 * Marks the socket as XMP-managed (tp->xmp = 1) and installs the ECE
 * send/receive callbacks in the private CA area.
 * NOTE(review): the callbacks are presumably invoked by the patched TCP
 * stack. They are not called from this file.
 * Also sets srtt_us to the "no sample" sentinel 1, clears the CE counter and
 * the increase accumulator, then resets the remaining per-round state.
 */
void mptcp_xmp_init(struct sock *sk)
{
	struct mptcp_xmp *ca = inet_csk_ca(sk);

	tcp_sk(sk)->xmp = 1;

	ca->send_ece = mptcp_xmp_send_ece;
	ca->receive_ece = mptcp_xmp_receive_ece;

	ca->srtt_us = 1;
	ca->adder = 0;
	ca->cnt_ce = 0;

	mptcp_xmp_reset(sk);
}
EXPORT_SYMBOL_GPL(mptcp_xmp_init);


void mptcp_xmp_release(struct sock *sk)
{
	tcp_sk(sk)->xmp = 0;
}
EXPORT_SYMBOL_GPL(mptcp_xmp_release);


/*
 * mptcp_xmp_pkts_acked - RTT sampling hook (tcp_congestion_ops.pkts_acked).
 * @sk:     subflow socket
 * @cnt:    number of packets acked (unused)
 * @rtt_us: RTT sample in microseconds; non-positive samples are ignored
 *
 * Maintains three values:
 *   baseRTT_us - minimum RTT since the last mptcp_xmp_reset()
 *   minRTT_us  - minimum RTT in the current round (re-armed by cong_avoid)
 *   srtt_us    - EWMA of the RTT with gain 1/2^mptcp_xmp_shift_usrtt
 *
 * srtt_us == 1 is the "no sample yet" sentinel set in mptcp_xmp_init().
 * The first valid sample now seeds srtt_us directly, as RFC 6298 does for
 * SRTT. Previously it was averaged against the sentinel, which kept srtt_us
 * far below the real RTT for many samples. That inflated instant_rate and
 * therefore skewed the subflow weights.
 */
void mptcp_xmp_pkts_acked(struct sock *sk, u32 cnt, s32 rtt_us)
{
	struct mptcp_xmp *ca;

	if (rtt_us <= 0)
		return;

	ca = inet_csk_ca(sk);
	if (rtt_us < ca->baseRTT_us)
		ca->baseRTT_us = rtt_us;
	if (rtt_us < ca->minRTT_us)
		ca->minRTT_us = rtt_us;

	if (ca->srtt_us <= 1)
		ca->srtt_us = rtt_us;
	else
		ca->srtt_us = ca->srtt_us - (ca->srtt_us >> mptcp_xmp_shift_usrtt)
			      + (rtt_us >> mptcp_xmp_shift_usrtt);
}
EXPORT_SYMBOL_GPL(mptcp_xmp_pkts_acked);


/*
 * mptcp_xmp_cong_avoid - XMP window growth (tcp_congestion_ops.cong_avoid).
 * @sk:        subflow socket
 * @ack:       acknowledged sequence number
 * @in_flight: packets in flight (only forwarded to tcp_reno_cong_avoid)
 *
 * Sockets without an mpcb, or not flagged mptcp/xmp, are left untouched.
 * Subflows with tp->mpc clear or without ECN fall back to Reno.
 *
 * Once per round (an ACK beyond ca->beg_seq) the function:
 *   - updates alpha, an EWMA (gain 1/2^mptcp_xmp_shift_g) of cnt_ece/cwnd,
 *     capped at 1.0;
 *   - recomputes instant_rate and the subflow weight;
 *   - grows cwnd when in MPTCP_XMP_STATE_NORMAL: slow start below ssthresh,
 *     otherwise weight * mptcp_xmp_ca_adder packets (fixed-point
 *     accumulated), clamped to snd_cwnd_clamp;
 *   - emits a KERN_DEBUG trace line and starts the next round.
 * Between round boundaries, slow start continues on every ACK. The
 * reduced state ends once @ack reaches ca->cwr_seq.
 *
 * Fixes compared with the earlier version:
 *   - the accumulator drain used "pkts << MPTCP_XMP_SCALE" in 32-bit
 *     arithmetic; it now uses mptcp_xmp_scale() (u64);
 *   - additive increase now respects snd_cwnd_clamp, as Reno does;
 *   - a non-positive mptcp_xmp_ca_adder no longer wraps to a huge u64
 *     increment.
 */
static void mptcp_xmp_cong_avoid(struct sock *sk, u32 ack, u32 in_flight)
{
	struct tcp_sock *tp = tcp_sk(sk);
	struct mptcp_cb *mpcb = tp->mpcb;
	struct mptcp_xmp *ca = inet_csk_ca(sk);

	if (!mpcb || !tp->mptcp || !tp->xmp)
		return;
	if (!tp->mpc || !(tp->ecn_flags & TCP_ECN_OK)) {
		tcp_reno_cong_avoid(sk, ack, in_flight);
		return;
	}

	/*
	 * NOTE(review): an early return on !tcp_is_cwnd_limited(sk, in_flight)
	 * was left disabled by the original author (marked "???"). cwnd
	 * therefore grows even when the sender is not cwnd-limited.
	 */

	/* per RTT */
	if (after(ack, ca->beg_seq)) {
		u32 cwnd = tp->snd_cwnd;
		u64 diff = 0;
		u64 total_rate = 0;

		if (!sysctl_mptcp_xmp_mode || ca->state != MPTCP_XMP_STATE_NORMAL) {
			/*
			 * old_cwnd is the window at the previous round boundary,
			 * or just before an ECE-triggered cut in this round.
			 */
			cwnd = ca->old_cwnd;
			/* backlogged-packet estimate; only used in the trace below */
			if (ca->minRTT_us != 0x7fffffff) {
				diff = div64_u64(mptcp_xmp_scale(ca->minRTT_us - ca->baseRTT_us, MPTCP_XMP_SCALE),
						 ca->minRTT_us);
				diff *= cwnd;
			}
		}

		/* alpha <- alpha * (1 - 2^-g) + (cnt_ece / cwnd) * 2^(SCALE - g) */
		ca->alpha = ca->alpha - (ca->alpha >> mptcp_xmp_shift_g)
			+ div64_u64(mptcp_xmp_scale(ca->cnt_ece, MPTCP_XMP_SCALE - mptcp_xmp_shift_g), cwnd);

		if (ca->alpha > (u64) 1 << MPTCP_XMP_SCALE)
			ca->alpha = (u64) 1 << MPTCP_XMP_SCALE;

		ca->instant_rate = mptcp_xmp_rate(cwnd, ca->srtt_us); /* pkts / us, fixed point */

		ca->weight = mptcp_xmp_weight(tp->mpcb, sk, &total_rate);

		/* in the safe area, increase cwnd */
		if (ca->state == MPTCP_XMP_STATE_NORMAL) {
			if (tp->snd_cwnd <= tp->snd_ssthresh) {
				tcp_slow_start(tp);
			} else if (mptcp_xmp_ca_adder > 0) {
				u32 pkts;

				ca->adder += ca->weight * mptcp_xmp_ca_adder;
				pkts = ca->adder >> MPTCP_XMP_SCALE;
				if (pkts > 0) {
					ca->adder -= mptcp_xmp_scale(pkts, MPTCP_XMP_SCALE);
					tp->snd_cwnd += pkts;
					if (tp->snd_cwnd > tp->snd_cwnd_clamp)
						tp->snd_cwnd = tp->snd_cwnd_clamp;
				}
			}
		}

		/* debug trace, one line per round */
		printk (KERN_DEBUG"xmp %llu %u-%u.%u.%u.%u-%u.%u.%u.%u "
			"cw:%u:%u:%u:%d:%d "
			"rw:%u "
			"rtt:%d:%d:%d "
			"diff:%llu "
			"alp:%llu "
			"wei:%llu "
			"rate:%llu:%llu "
			"\n", 
			ktime_to_us(ktime_get_real()), 
			tp->inet_conn.icsk_inet.inet_dport, 
			NIPQUAD(sk->__sk_common.skc_rcv_saddr), 
			NIPQUAD(sk->__sk_common.skc_daddr), 
			cwnd, tp->snd_cwnd, ca->old_cwnd, ca->cnt_ece, ca->state, 
			tp->snd_wnd, 
			ca->srtt_us, ca->baseRTT_us, ca->minRTT_us, 
			diff, 
			ca->alpha, 
			ca->weight, 
			ca->instant_rate, total_rate
			);

		/* for the next round */
		ca->beg_seq  = tp->snd_nxt;
		ca->cnt_ece = 0;
		ca->minRTT_us = 0x7fffffff;
		ca->old_cwnd = tp->snd_cwnd;
	}
	/* Use normal slow start */
	else if (tp->snd_cwnd <= tp->snd_ssthresh && ca->state == MPTCP_XMP_STATE_NORMAL)
		tcp_slow_start(tp);

	/* return to the normal state once the reduction point is acked */
	if (ca->state != MPTCP_XMP_STATE_NORMAL &&
	    !before(ack, ca->cwr_seq)) {
		ca->state = MPTCP_XMP_STATE_NORMAL;
	}
}


/*
 * mptcp_xmp_set_state - tcp_congestion_ops.set_state hook.
 * @sk:       subflow socket
 * @ca_state: congestion-avoidance state being entered
 *
 * Ignored unless tp->mpc is set and ECN is negotiated. Behaviour by state:
 *   TCP_CA_Recovery: refresh instant_rate from the current cwnd (0 if no
 *                    RTT sample yet);
 *   TCP_CA_Loss:     reset all per-round XMP state;
 *   TCP_CA_CWR:      nothing besides the trace.
 * Those three states emit a KERN_DEBUG "loss" trace line. Other states are
 * ignored.
 * NOTE(review): icsk_ca_state is printed next to @ca_state, presumably as
 * the state being left. Confirm the hook's call order in the TCP stack.
 */
void mptcp_xmp_set_state(struct sock *sk, u8 ca_state)
{
	struct tcp_sock *tp = tcp_sk(sk);
	struct mptcp_xmp *ca;

	if (!tp->mpc || !(tp->ecn_flags & TCP_ECN_OK))
		return;

	ca = inet_csk_ca(sk);
	switch (ca_state) {
	case TCP_CA_Recovery:
		ca->instant_rate = (ca->srtt_us > 1) ?
			mptcp_xmp_rate(tp->snd_cwnd, ca->srtt_us) : 0;
		break;
	case TCP_CA_Loss:
		mptcp_xmp_reset(sk);
		break;
	case TCP_CA_CWR:
		break;
	default:
		return;
	}

	printk (KERN_DEBUG"loss %llu %u-%u.%u.%u.%u-%u.%u.%u.%u "
		"cw:%u:%u:%d "
		"rw:%u "
		"rtt:%d:%d:%d "
		"st:%u:%u "
		"\n", 
		ktime_to_us(ktime_get_real()), 
		tp->inet_conn.icsk_inet.inet_dport, 
		NIPQUAD(sk->__sk_common.skc_rcv_saddr), 
		NIPQUAD(sk->__sk_common.skc_daddr), 
		tp->snd_cwnd, ca->old_cwnd, ca->cnt_ece, 
		tp->snd_wnd, 
		ca->srtt_us, ca->baseRTT_us, ca->minRTT_us, 
		inet_csk(sk)->icsk_ca_state, ca_state
		);
}
EXPORT_SYMBOL_GPL(mptcp_xmp_set_state);


/*
 * mptcp_xmp_cwnd_event - tcp_congestion_ops.cwnd_event hook.
 * @sk:    subflow socket
 * @event: congestion-window event
 *
 * On CA_EVENT_CWND_RESTART or CA_EVENT_TX_START, discards all per-round XMP
 * state via mptcp_xmp_reset(). Other events are ignored.
 */
void mptcp_xmp_cwnd_event(struct sock *sk, enum tcp_ca_event event)
{
	switch (event) {
	case CA_EVENT_CWND_RESTART:
	case CA_EVENT_TX_START:
		mptcp_xmp_reset(sk);
		break;
	default:
		break;
	}
}
EXPORT_SYMBOL_GPL(mptcp_xmp_cwnd_event);


/*
 * Congestion-control descriptor registered under the name "xmp".
 * ssthresh and min_cwnd reuse the stock Reno helpers; everything else is
 * XMP-specific.
 * NOTE(review): TCP_CONG_RTT_STAMP is the older-kernel flag that requests
 * RTT timestamps for pkts_acked. Confirm it exists in the target kernel.
 */
static struct tcp_congestion_ops mptcp_xmp_ops __read_mostly = {
	.name		= "xmp",
	.owner		= THIS_MODULE,
	.flags		= TCP_CONG_RTT_STAMP,

	/* lifecycle */
	.init		= mptcp_xmp_init,
	.release	= mptcp_xmp_release,

	/* window control */
	.ssthresh	= tcp_reno_ssthresh,
	.min_cwnd	= tcp_reno_min_cwnd,
	.cong_avoid	= mptcp_xmp_cong_avoid,

	/* notifications */
	.pkts_acked	= mptcp_xmp_pkts_acked,
	.set_state	= mptcp_xmp_set_state,
	.cwnd_event	= mptcp_xmp_cwnd_event,
};


/*
 * mptcp_xmp_register - module init.
 *
 * Refuses to build if struct mptcp_xmp does not fit in the per-socket CA
 * private area, then registers the "xmp" congestion control. Returns the
 * result of tcp_register_congestion_control().
 */
static int __init mptcp_xmp_register(void)
{
	int err;

	BUILD_BUG_ON(sizeof(struct mptcp_xmp) > ICSK_CA_PRIV_SIZE);

	err = tcp_register_congestion_control(&mptcp_xmp_ops);
	return err;
}


/*
 * mptcp_xmp_unregister - module exit.
 *
 * Forces sysctl_mptcp_xmp_mode back to 0, then unregisters the "xmp"
 * congestion control.
 */
static void __exit mptcp_xmp_unregister(void)
{
	struct tcp_congestion_ops *ops = &mptcp_xmp_ops;

	sysctl_mptcp_xmp_mode = 0;
	tcp_unregister_congestion_control(ops);
}


/* Module entry/exit points and metadata. */
module_init(mptcp_xmp_register);
module_exit(mptcp_xmp_unregister);

MODULE_AUTHOR("Yu Cao, Enhuan Dong");
MODULE_LICENSE("GPL");
MODULE_DESCRIPTION("ECN-based MultiPath Congestion Control (XMP) for Datacenters");
MODULE_VERSION("0.1");

