/*
 * mcastseed.c - Multicast sender for huge streams to be piped to other programs (partitions cloning)
 *
 * Copyright 2016 by Ludovic Pouzenc <ludovic@pouzenc.fr>
 *
 * Greatly inspired from examples written by tmouse, July 2005
 * http://cboard.cprogramming.com/showthread.php?t=67469
 */
#define _GNU_SOURCE /* See feature_test_macros(7) */
#include "config.h"

#include <unistd.h> /* close() */
#include <stdio.h> /* fprintf(), stderr */
#include <stdlib.h> /* atoi(), EXIT_SUCCESS */
#include <string.h> /* strlen() */
#include <sys/select.h> /* select(), FD_ZERO(), FD_SET() */
#include "sockets.h"

/* Size in bytes of readbuf, the buffer used for each read() */
#define READ_BUF_LEN 256
/* Second argument given to ucast_server_socket() (presumably the listen() backlog - TODO confirm in sockets.h) */
#define MAX_PENDING_CONNECTIONS 256
/* Capacity of the clients[] table; further connections are bounced */
#define MAX_CLIENTS 256
/* Assumed link MTU in bytes */
#define MTU 1500
/* Linux IPv6 fragmentation does not output Ethernet frames larger than 1470 bytes when MTU==1500 */
/* Evaluates to 1408. NOTE(review): 40 and 8 look like the IPv6 and UDP header sizes, (14+30) is not explained here - confirm */
#define MULTICAST_MAX_PAYLOAD_SIZE (MTU-40-8-(14+30))

/* Defaults used by arg_parse() when the matching command line argument is absent */
#define DEFAULT_MCAST_IP_STR "ff02::114"
#define DEFAULT_PORT_STR "9000"
#define DEFAULT_MCAST_TTL 1

/* Command line arguments, filled in by arg_parse() */
char *prog_name = NULL; /* argv[0], used in the usage message */
char *mcast_ip = NULL; /* Multicast group address, as text */
char *port = NULL; /* Port, as text; given to both socket setup helpers */
int mcast_ttl = 0; /* Multicast TTL given to mcast_send_socket() */

/* Sockets as global, used everywhere, even in die() */
int mcast_sock = -1; /* Multicast socket for sending data */
int ucast_sock = -1; /* Unicast socket to get feedback from clients */

/* Socket related data */
struct addrinfo *mcast_addr = NULL; /* Destination of every multicast sendto(); set by mcast_send_socket() */
/*
 * One entry per accepted client connection.
 * Values written to 'state' in this file:
 *   0 = just accepted, 1 = acked "ready", 2 = dead (lost or misbehaving), 3 = acked "done."
 * NOTE(review): 'addr' is a plain struct sockaddr, which may be too small to hold
 * an IPv6 peer address - confirm whether the address is used anywhere.
 */
struct client {
	int sock;
	struct sockaddr addr;
	int state;
} clients[MAX_CLIENTS];
int clients_next = 0; /* Number of used entries in clients[], also the next free index */

/* Buffer used for each read() */
char readbuf[READ_BUF_LEN];

/*
 * Printable names of the states of the program, indexed by the state number
 * used in main() (0 = start/normal end, 1..6 = protocol steps).
 * fsm_trace() reads this table.
 */
const char * const state_str[] = {
	"start",
	"send_hello",
	"accept_pending_clients_or_wait_a_bit",
	"start_job",
	"send_data",
	"wait_all_finalize_job",
	"is_there_more_job"
};

/* Some boring funcs you didn't want to read now */
void die(char* msg);
void usage(char *msg);
void arg_parse(int argc, char* argv[]);
void fsm_trace(int state);
void setup_sockets();
void unsetup_sockets();

/* Parts of the "protocol", definitions are after main() */
int send_hello();
int accept_pending_clients_or_wait_a_bit();
int start_job();
int send_data();
int wait_all_finalize_job();
int is_there_more_job();

/*
 * Program entry point: parse the command line, set up both sockets, then
 * drive the "protocol" finite state machine until it reaches state 0
 * (normal end) or a negative state (the step of that number failed).
 *
 * Returns EXIT_SUCCESS on normal end, otherwise the positive number of the
 * state in which the failure happened.
 */
int main(int argc, char *argv[]) {
	int state = 1; /* state of the "protocol" state machine */
	int res;
	arg_parse(argc, argv);
	setup_sockets();

	/* Finite state machine */
	while ( state > 0 ) {
		fsm_trace(state);
		switch ( state ) {
			case 1: res = send_hello(); state = (res==0)?2:-1; break;
			case 2: res = accept_pending_clients_or_wait_a_bit();
				if (res==0) state = 2; /* Some clients has just come in, try to get more */
				else if	(res==1) state = 1; /* Nothing new. Keep accepting clients after another hello */
				else if (res==2) state = 3; /* Wanted clients are accepted */
				else state = -2;
				break;
			case 3: res = start_job();
				if (res==0) state = 3; /* Keep trying to convince every client to start */
				else if (res==1) state = 4; /* All clients have started the job pipe */
				else if	(res==2) state = 4; /* There is dead clients but all alive are ready to go */
				else state = -3;
				break;
			case 4: res = send_data();
				if (res==0) state = 4;
				else if (res==1) state = 5; /* All data sent */
				else state = -4;
				break;
			case 5: res = wait_all_finalize_job();
				if (res==0) state = 5; /* Keep waiting */
				else if (res==1) state = 6;
				else state = -5;
				/* This break was missing: case 5 fell through into case 6,
				 * which overwrote the state just chosen (including -5). */
				break;
			case 6: res = is_there_more_job();
				if (res==0) state = 0;
				else if (res==1) state = 3;
				else state = -6;
				break;
			default:
				/* Not reachable with the transitions above; avoids spinning forever */
				state = -7;
				break;
		}
	}
	fsm_trace(state);

	unsetup_sockets();

	if ( state < 0 )
	  return -state;

	return EXIT_SUCCESS;
}

/*
 * Multicast a single "hello" datagram to mcast_addr through mcast_sock.
 * Returns 0 once the datagram is handed to the kernel (a short send is
 * only reported on stderr), -1 if sendto() failed.
 */
int send_hello() {
	static const char hello_msg[] = "hello";
	const int hello_len = (int) strlen(hello_msg);
	ssize_t sent;

	sent = sendto(mcast_sock, hello_msg, hello_len, 0,
			mcast_addr->ai_addr, mcast_addr->ai_addrlen);

	if ( sent < 0 ) {
		perror("sendto() failed");
		return -1;
	}

	if ( sent < hello_len )
		fprintf(stderr, "%s", "Short packet sent");

	return 0;
}

/*
 * Wait up to 2 seconds for activity on the unicast listening socket or on
 * stdin (fd 0), and register at most one new client in clients[].
 *
 * Returns:
 *    0  something happened on ucast_sock (caller should call again)
 *    1  nothing happened before the timeout
 *    2  stdin became readable (or was lost): the user wants to go now
 *   -1  select() failed
 *   -2  exceptional condition reported on ucast_sock
 *
 * A failed accept() is reported on stderr and does not register anything;
 * formerly the -1 it returned was stored as a client descriptor.
 */
int accept_pending_clients_or_wait_a_bit() {
	struct timeval timeout;
	fd_set readfds, exceptfds;
	ssize_t nread;
	int res;
	int new_sock;

	FD_ZERO(&readfds);
	FD_ZERO(&exceptfds);
	FD_SET(0,&readfds);
	FD_SET(ucast_sock,&readfds);
	FD_SET(ucast_sock,&exceptfds);
	timeout.tv_sec = 2;
	timeout.tv_usec = 0;

	res = select(ucast_sock+1, &readfds, NULL, &exceptfds, &timeout);
	if ( res < 0 ) {
		perror("select() failed");
		return -1;
	}
	if ( res == 0 ) {
		/* Nothing happened before timeout */
		return 1;
	}

	if (FD_ISSET(ucast_sock, &readfds)) {
		/*TODO : this assumes that the event is an accept() while ones could be send data there */
		if ( clients_next >= MAX_CLIENTS ) {
			fprintf(stderr, "%s\n", "Bouncing client, MAX_CLIENTS reached");
			/* Take the connection out of the queue only to drop it */
			new_sock = accept(ucast_sock, NULL, NULL);
			if ( new_sock < 0 )
				perror("accept() failed");
			else
				close(new_sock);
		} else {
			/* NOTE(review): addr is a struct sockaddr; a longer peer address
			 * (e.g. IPv6) would be stored truncated by accept() */
			socklen_t addrlen = sizeof(clients[clients_next].addr);
			new_sock = accept(ucast_sock, &(clients[clients_next].addr), &addrlen);
			if ( new_sock < 0 ) {
				/* Do not register an invalid descriptor as a client */
				perror("accept() failed");
			} else {
				clients[clients_next].sock = new_sock;
				clients[clients_next].state = 0;
				printf("Connected client on fd %i\n", new_sock);
				clients_next++;
			}
		}
	}
	/*TODO : drop this keybord read with accept(), this is not portable */
	if ( FD_ISSET(0, &readfds)) {
		nread = read(0, readbuf, READ_BUF_LEN);
		if ( nread <= 0 ) {
			fprintf(stderr, "%s\n", "lost stdin");
		}
		/* User wants to go now */
		return 2;
	}
	if (FD_ISSET(ucast_sock, &exceptfds)) {
		fprintf(stderr, "%s\n", "unhandled except on ucast_sock");
		return -2;
	}

	return 0;
}

/*
 * Multicast a "start" order, then wait up to 2 seconds for "ready" acks on
 * the clients' unicast sockets.
 *
 * Client states written here: 1 = acked "ready", 2 = dead (lost connection,
 * short or unexpected data, exceptional condition, unusable descriptor).
 * Dead clients are no longer watched.
 *
 * Returns:
 *    1  every known client is ready
 *    2  there are dead clients, but every other client is ready
 *    0  some live client has not acked yet (caller should call again)
 *   -1  sendto() or select() failed
 */
int start_job() {
	struct timeval timeout;
	fd_set readfds, exceptfds;
	ssize_t nread, nwrite;
	int all_ready, all_non_dead_ready;
	int i, res, max_fd;
	int client_sock;
	const char *payload = "start";
	int paylen = strlen(payload);
	const char *ack = "ready";
	size_t acklen = strlen(ack);

	nwrite = sendto(mcast_sock, payload, paylen, 0, mcast_addr->ai_addr, mcast_addr->ai_addrlen);
	if ( nwrite < 0 ) {
		perror("sendto() failed");
		return -1;
	}
	if ( nwrite < paylen ) {
		fprintf(stderr, "%s", "Short packet sent");
	}

	FD_ZERO(&readfds);
	FD_ZERO(&exceptfds);
	max_fd = -1;
	for ( i=0; i<clients_next; i++) {
		client_sock = clients[i].sock;
		/* A descriptor that can't be put in an fd_set can't be waited on */
		if ( client_sock < 0 || client_sock >= FD_SETSIZE )
			clients[i].state = 2;
		/* A dead client would make select() return at once, over and over */
		if ( clients[i].state == 2 )
			continue;
		FD_SET(client_sock,&readfds);
		FD_SET(client_sock,&exceptfds);
		if ( client_sock > max_fd )
			max_fd = client_sock;
	}
	timeout.tv_sec = 2;
	timeout.tv_usec = 0;
	/* First argument is the highest watched descriptor plus one, not a count */
	res = select(max_fd+1, &readfds, NULL, &exceptfds, &timeout);
	if ( res < 0 ) {
		perror("select() failed");
		return -1;
	}

	/* Readiness is evaluated even after a timeout (res == 0): otherwise
	 * "nobody answered" would be reported as "everybody is ready". */
	all_ready = 1;
	all_non_dead_ready = 1;
	for ( i=0; i<clients_next; i++) {
		client_sock = clients[i].sock;
		if ( res > 0 && clients[i].state != 2 ) {
			if (FD_ISSET(client_sock, &readfds)) {
				printf("todo info from client %i\n", i);
				nread = read(client_sock, readbuf, acklen);
				if ( nread <= 0 ) {
					fprintf(stderr, "lost client %i\n", i);
					clients[i].state = 2;
				} else if ( (size_t) nread < acklen ) {
					fprintf(stderr, "short data from %i\n", i);
					clients[i].state = 2;
				} else if ( strncmp(ack, readbuf, acklen) != 0 ) {
					fprintf(stderr, "unexpected data from %i\n", i);
					clients[i].state = 2;
				} else {
					/* Received "ready" ack from client */
					clients[i].state = 1;
				}
			}
			if (FD_ISSET(client_sock, &exceptfds)) {
				fprintf(stderr, "unhandled except on client %i\n", i);
				clients[i].state = 2;
			}
		}
		all_ready &= (clients[i].state == 1);
		if ( clients[i].state != 2)
			all_non_dead_ready &= (clients[i].state == 1);
	}

	if ( all_ready )
		return 1;
	if ( all_non_dead_ready )
		return 2;

	return 0;
}

/*
 * Stamp number i into a dummy data packet and multicast it.
 *
 * buf    : packet template; must hold at least 34 bytes. Bytes 4..7 receive
 *          i in network byte order, bytes 28..32 receive i as 5 decimal
 *          digits (longer values are cut to their first 5 characters) and
 *          byte 33 is reset to ')' after snprintf() put its terminator there.
 * paylen : number of bytes of buf to send.
 * i      : number to stamp.
 *
 * The result of sendto() is not checked (best effort, as before).
 */
void send_fake(char buf[], int paylen, int i) {
	uint32_t num_be = htonl(i);

	/* memcpy instead of a store through a cast pointer: buf is a char
	 * array with no alignment guarantee for uint32_t */
	memcpy(buf+4, &num_be, sizeof(num_be));
	snprintf(buf+28, 6, "%05i", i);
	buf[33] = ')';
	sendto(mcast_sock, buf, paylen, 0, mcast_addr->ai_addr, mcast_addr->ai_addrlen);
}

/*
 * Dummy data sender: multicast a fixed series of numbered fake packets,
 * then one short final packet.
 *
 * The numbers are deliberately sent out of order and number 1 is sent
 * twice (presumably to exercise the receivers - TODO confirm). Only the
 * final sendto() is checked.
 *
 * Returns 1 when everything has been sent, -1 if the final sendto() failed.
 */
int send_data() {
	ssize_t nwrite;
	char buf[MULTICAST_MAX_PAYLOAD_SIZE];
	int paylen = MULTICAST_MAX_PAYLOAD_SIZE;
	int i;
	uint32_t num_be;

	/* XXX Dummy */
	/* Full-size template: header text, then dots, ending with a newline */
	memset(buf, '.', MULTICAST_MAX_PAYLOAD_SIZE-1);
	buf[MULTICAST_MAX_PAYLOAD_SIZE-1]='\n';
	strcpy(buf, "dataXXXXJe suis a la plage (XXXXX)");

	send_fake(buf, paylen, 5);
	send_fake(buf, paylen, 4);
	send_fake(buf, paylen, 3);

	/* Even numbers first, then odd ones */
	for (i=6; i<=100000; i+=2) {
		send_fake(buf, paylen, i);
	}
	for (i=7; i<=100000; i+=2) {
		send_fake(buf, paylen, i);
	}

	send_fake(buf, paylen, 1);
	send_fake(buf, paylen, 1);
	send_fake(buf, paylen, 2);

	/* Final packet: number 3 at bytes 4..7, text cut short after "mer.\n".
	 * memcpy instead of a store through a cast pointer: buf has no
	 * alignment guarantee for uint32_t. */
	num_be = htonl(3);
	memcpy(buf+4, &num_be, sizeof(num_be));
	memcpy(buf+21, "mer.\n", 5);
	paylen = 26;
	nwrite = sendto(mcast_sock, buf, paylen, 0, mcast_addr->ai_addr, mcast_addr->ai_addrlen);
	if ( nwrite < 0 ) {
		perror("sendto() failed");
		return -1;
	}
	if ( nwrite < paylen ) {
		fprintf(stderr, "%s", "Short packet sent");
	}

	return 1;
}


/*
 * Multicast an "end:" packet carrying the number 100000 in network byte
 * order, then wait up to 2 seconds for "done." acks on the clients'
 * unicast sockets.
 *
 * Client states written here: 3 = acked "done.", 2 = dead (lost connection,
 * short or unexpected data, exceptional condition, unusable descriptor).
 * Dead clients are no longer watched.
 *
 * Returns:
 *    1  every client that is not dead has acked "done."
 *    0  some live client has not acked yet (caller should call again)
 *   -1  sendto() or select() failed
 */
int wait_all_finalize_job() {
	struct timeval timeout;
	fd_set readfds, exceptfds;
	ssize_t nread, nwrite;
	int all_non_dead_done;
	int i, res, max_fd;
	int client_sock;
	char buf[] = "end:XXXX";
	/* Length taken from the array: the number written below may contain zero bytes */
	int paylen = sizeof(buf) - 1;
	uint32_t num_be = htonl(100000);
	const char *ack = "done.";
	size_t acklen = strlen(ack);

	/* Replace the XXXX placeholder; memcpy avoids a misaligned uint32_t store */
	memcpy(buf+4, &num_be, sizeof(num_be));
	nwrite = sendto(mcast_sock, buf, paylen, 0, mcast_addr->ai_addr, mcast_addr->ai_addrlen);
	if ( nwrite < 0 ) {
		perror("sendto() failed");
		return -1;
	}
	if ( nwrite < paylen ) {
		fprintf(stderr, "%s", "Short packet sent");
	}

	FD_ZERO(&readfds);
	FD_ZERO(&exceptfds);
	max_fd = -1;
	for ( i=0; i<clients_next; i++) {
		client_sock = clients[i].sock;
		/* A descriptor that can't be put in an fd_set can't be waited on */
		if ( client_sock < 0 || client_sock >= FD_SETSIZE )
			clients[i].state = 2;
		/* A dead client would make select() return at once, over and over */
		if ( clients[i].state == 2 )
			continue;
		FD_SET(client_sock,&readfds);
		FD_SET(client_sock,&exceptfds);
		if ( client_sock > max_fd )
			max_fd = client_sock;
	}
	timeout.tv_sec = 2;
	timeout.tv_usec = 0;
	/* First argument is the highest watched descriptor plus one, not a count */
	res = select(max_fd+1, &readfds, NULL, &exceptfds, &timeout);
	if ( res < 0 ) {
		perror("select() failed");
		return -1;
	}

	/* Completion is evaluated even after a timeout (res == 0): otherwise
	 * "nobody answered" would be reported as "everybody is done". */
	all_non_dead_done = 1;
	for ( i=0; i<clients_next; i++) {
		client_sock = clients[i].sock;
		if ( res > 0 && clients[i].state != 2 ) {
			if (FD_ISSET(client_sock, &readfds)) {
				printf("todo info from client %i\n", i);
				nread = read(client_sock, readbuf, acklen);
				if ( nread <= 0 ) {
					fprintf(stderr, "lost client %i\n", i);
					clients[i].state = 2;
				} else if ( (size_t) nread < acklen ) {
					fprintf(stderr, "short data from %i\n", i);
					clients[i].state = 2;
				} else if ( strncmp(ack, readbuf, acklen) != 0 ) {
					fprintf(stderr, "unexpected data from %i\n", i);
					clients[i].state = 2;
				} else {
					/* Received "done." ack from client */
					clients[i].state = 3;
				}
			}
			if (FD_ISSET(client_sock, &exceptfds)) {
				fprintf(stderr, "unhandled except on client %i\n", i);
				clients[i].state = 2;
			}
		}
		if ( clients[i].state != 2)
			all_non_dead_done &= (clients[i].state == 3);
	}

	if ( all_non_dead_done )
		return 1;

	return 0;
}


/*
 * Last step of the state machine: tell whether another job should be run.
 * Always answers 0, which main() treats as "normal end".
 */
int is_there_more_job() {
	const int no_more_job = 0;

	return no_more_job;
}


/*
 * Fatal exit: print msg followed by a newline on stderr, close the
 * multicast then the unicast socket when their descriptor is positive,
 * and terminate the process with EXIT_FAILURE. Never returns.
 */
void die(char* msg) {
	const int open_socks[2] = { mcast_sock, ucast_sock };
	int k;

	fprintf(stderr, "%s\n", msg);

	for ( k = 0; k < 2; k++ ) {
		if ( open_socks[k] > 0 )
			close(open_socks[k]);
	}

	exit(EXIT_FAILURE);
}

/*
 * Print msg (when there is one) on stderr, then build the usage line from
 * prog_name and hand it to die(), which prints it and exits.
 */
void usage(char *msg) {
	char ubuf[256] = "";

	if ( msg )
		fprintf(stderr, "%s\n", msg);

	snprintf(ubuf, sizeof(ubuf) - 1,
			"Usage: %s [port] [mcast_ip] [mcast_ttl]\n", prog_name);
	die(ubuf);
}

/*
 * Fill the command line globals from argv: [port] [mcast_ip] [mcast_ttl].
 *
 * Missing arguments take DEFAULT_PORT_STR, DEFAULT_MCAST_IP_STR and
 * DEFAULT_MCAST_TTL. A mcast_ttl that is not a plain decimal number in
 * 1..64 falls back to DEFAULT_MCAST_TTL. More than three arguments exits
 * through usage().
 *
 * The limit used to be "argc > 3", which rejected the third argument and
 * made mcast_ttl impossible to set from the command line.
 */
void arg_parse(int argc, char* argv[]) {
	prog_name = argv[0];
	if ( argc > 4 )
		usage("Too many arguments");
	port = (argc >= 2)?argv[1]:DEFAULT_PORT_STR;
	mcast_ip = (argc >= 3)?argv[2]:DEFAULT_MCAST_IP_STR;
	mcast_ttl = DEFAULT_MCAST_TTL;
	if ( argc >= 4 ) {
		char *end = NULL;
		long ttl = strtol(argv[3], &end, 10);
		/* Accept only a whole, in-range number; an overflowing value
		 * comes back as LONG_MAX/LONG_MIN and fails the range test */
		if ( end != argv[3] && *end == '\0' && ttl >= 1 && ttl <= 64 )
			mcast_ttl = (int) ttl;
	}
}

/*
 * Report the state machine's progress on stderr.
 *
 * A negative state is always reported as an abnormal exit. Otherwise a
 * line is printed only when state differs from the one remembered from the
 * previous report: "Normal exit" for state 0, "Now in ..." for the others.
 * The remembered state is not changed by a negative state.
 */
void fsm_trace(int state) {
	static int last_seen = 0;
	const char *from = state_str[last_seen];

	if ( state < 0 ) {
		fprintf(stderr, "Abnormal exit condition %i (from %s)\n", state, from);
		return;
	}

	/* Staying in the same state is not worth a line */
	if ( state == last_seen )
		return;

	if ( state == 0 )
		fprintf(stderr, "Normal exit (from %s)\n",  from);
	else
		fprintf(stderr, "Now in %s (from %s)\n", state_str[state], from);

	last_seen = state;
}

/*
 * Create the two sockets of the program from the parsed arguments:
 * first ucast_sock, then mcast_sock (which also fills mcast_addr).
 * A negative result from either helper ends the program through usage().
 */
void setup_sockets() {
	/* Unicast side */
	ucast_sock = ucast_server_socket(port, MAX_PENDING_CONNECTIONS);
	if ( ucast_sock < 0 ) {
		usage("Could not setup unicast socket. Wrong args given ?");
	}

	/* Multicast side */
	mcast_sock = mcast_send_socket(mcast_ip, port, mcast_ttl, &mcast_addr);
	if ( mcast_sock < 0 ) {
		usage("Could not setup multicast socket. Wrong args given ?");
	}
}

/*
 * Release what setup_sockets() acquired: close both sockets and free the
 * multicast address list. Safe to call more than once.
 *
 * Closed descriptors go back to -1, the value they start with; they used
 * to be set to 0, which is itself a valid descriptor. Any descriptor >= 0
 * is closed, and mcast_addr is freed whether or not mcast_sock was open.
 */
void unsetup_sockets() {
	if ( ucast_sock >= 0 ) {
		close(ucast_sock);
		ucast_sock = -1;
	}

	if ( mcast_sock >= 0 ) {
		close(mcast_sock);
		mcast_sock = -1;
	}

	if ( mcast_addr != NULL ) {
		freeaddrinfo(mcast_addr);
		mcast_addr = NULL;
	}
}