/**
 * ifping main program
 *
 *  Archifishal Software, 2005
 *
 * $Id: $
 *
 * @file
 * @author Alex Macfarlane Smith
 */

#include <assert.h>
#include <ctype.h>
#include <errno.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <netdb.h>

#include <sys/types.h>

#include <sys/errno.h>
#include <sys/ioctl.h>
#include <sys/socket.h>
#include <socklib.h>
#include <inetlib.h>

#include <netinet/in_systm.h>
#include <netinet/in.h>
#include <netinet/ip.h>
#include <netinet/ip_icmp.h>

#include "kernel.h"
#include "swis.h"

#include "dnslib.h"

/* Minimal boolean type; this file defines its own rather than using
 * <stdbool.h>. */
typedef enum
{
    false,      /* implicitly 0 */
    true        /* implicitly 1 */
} bool;

/* Raw ICMP socket handle, or -1 while no socket is open (exitFn() only
 * closes it when it is not -1). */
static int  socket_handle = -1;

/* Set by the -v option; gates all DBUGF() output. */
static bool verbose = false;

/* Debug printf: call as DBUGF(("format", args...)) -- the double
 * parentheses pass the whole printf argument list as one macro argument. */
#define DBUGF(args) do { if (verbose) { printf args; } } while (0)

/**
 * Case-insensitive string comparison (a local replacement for the POSIX
 * function -- presumably because the target C library lacks it).
 *
 * @param s1  first NUL-terminated string
 * @param s2  second NUL-terminated string
 * @return 0 if the strings are equal ignoring case; otherwise negative or
 *         positive according to the first differing lower-cased character.
 */
static int strcasecmp(const char *s1, const char *s2)
{
    /* <ctype.h> functions need values representable as unsigned char;
     * passing a negative plain char is undefined behaviour. */
    const unsigned char *p1 = (const unsigned char *) s1;
    const unsigned char *p2 = (const unsigned char *) s2;

    for (;;)
    {
        int c1 = tolower(*p1++);
        int c2 = tolower(*p2++);

        /* Stop at the first difference, or when both strings end together.
         * Stopping only at the end of s1 would make any prefix of s2
         * (e.g. "th" vs "Then") compare equal. */
        if (c1 != c2 || c1 == '\0')
        {
            return c1 - c2;
        }
    }
}

/* This is borrowed from ping.c */

/**
 * One's-complement checksum over len bytes starting at addr.
 *
 * 16-bit words are summed into a 32-bit accumulator. A trailing odd byte
 * is placed in the first byte of a zeroed 16-bit word before being added.
 * The carries above bit 15 are then folded back into the low 16 bits, and
 * the complement of the result is returned.
 *
 * @param addr  data to checksum (read as unsigned shorts)
 * @param len   number of bytes to include
 * @return the 16-bit checksum, in the range 0..0xFFFF
 */
static int in_cksum(unsigned short *addr, int len)
{
    unsigned int    acc  = 0;
    unsigned short *word = addr;
    int             remaining;

    for (remaining = len; remaining > 1; remaining -= 2)
    {
        acc += *word++;
    }

    if (remaining == 1)
    {
        unsigned short last = 0;

        /* Copy the odd byte into the first byte in memory of a zeroed word. */
        *(unsigned char *) &last = *(unsigned char *) word;
        acc += last;
    }

    acc = (acc >> 16) + (acc & 0xffff);   /* fold high half into low half */
    acc += (acc >> 16);                   /* fold in any carry from that  */

    return (unsigned short) ~acc;         /* complement, truncated to 16 bits */
}

/**
 * atexit() handler registered by main(): closes the raw ICMP socket if it
 * is still open (socket_handle is not -1), e.g. when an error path calls
 * exit() after the socket was created.
 */
static void exitFn(void)
{
    if (socket_handle == -1)
    {
        return;     /* nothing open */
    }

    socketclose(socket_handle);
}

/**
 * Print the one-line usage summary to stdout.
 *
 * @param argv  program arguments; argv[0] is used as the program name
 */
static void syntaxDisplay(char *argv[])
{
    const char *progname = argv[0];

    fprintf(stdout,
            "Syntax: %s -h | [-v] [-t <time>] <host> then "
            "<command> [else <command>]\n",
            progname);
}

/**
 * Print the usage summary and terminate with EXIT_FAILURE; never returns.
 *
 * @param argv  program arguments (argv[0] names the program)
 */
static void syntaxExit(char *argv[])
{
    /* Show how the command should have been invoked, then bail out. */
    syntaxDisplay(argv);
    exit(EXIT_FAILURE);
}

/**
 * Build "%If <success> <argv[i]> ... <argv[argc-1]>" and run it with
 * _kernel_system().
 *
 * main() passes i indexing the "then" keyword. The resulting line is
 * therefore "%If <0|1> then <command> [else <command>]", and RISC OS *If
 * selects which branch to run.
 *
 * @param argc     argument count
 * @param argv     argument vector
 * @param i        index of the first argument to append
 * @param success  1 if the ping succeeded, 0 otherwise (asserted)
 */
static void launchCommand(int argc, char *argv[], int i, int success)
{
    char   *command;
    char   *commandPtr;
    size_t  len = 0;
    int     j;
    int     rc;

    assert(success == 0 || success == 1);

    /* Room for every argument plus one separating space each. */
    for (j = i; j < argc; j++)
    {
        len += strlen(argv[j]) + 1;
    }

    /* sizeof() of the literal already counts a NUL; the extra byte keeps
     * the original safety margin for the terminator sprintf() writes. */
    command = malloc(len + sizeof("%If x ") + 1);
    if (command == NULL)
    {
        fprintf(stderr, "%s: ran out of memory\n", argv[0]);
        exit(EXIT_FAILURE);
    }

    /* Advance by what sprintf() actually wrote rather than assuming the
     * width of the formatted flag. */
    commandPtr = command + sprintf(command, "%%If %d ", success);

    for (j = i; j < argc; j++)
    {
        size_t argLen = strlen(argv[j]);

        memcpy(commandPtr, argv[j], argLen);
        commandPtr += argLen;
        *commandPtr++ = ' ';
    }

    /* Replace the final separator (or the prefix's trailing space when
     * there are no arguments) with the terminator. */
    commandPtr[-1] = '\0';

    DBUGF(("ifping: executing command '%s'\n", command));

    /* NOTE(review): the second argument 1 requests chaining; if the command
     * replaces this program, the lines below may never run. */
    rc = _kernel_system(command, 1);

    DBUGF(("ifping: command returned %d\n", rc));

    free(command);
}

static int sendPacket(unsigned short      ident,
                      unsigned short      seq,
                      struct sockaddr_in *addr)
{
    char         packet[sizeof(struct icmp)];
    struct icmp *icmp;
    int          sizeOfAddr = sizeof(struct sockaddr_in);

    memset(packet, 0, sizeof(packet));
    icmp = (struct icmp *) packet;

    icmp->icmp_type = ICMP_ECHO;
    icmp->icmp_code = htons(0);
    icmp->icmp_cksum = htons(0);
    icmp->icmp_seq = htons(seq);
    icmp->icmp_id = htons(ident);

    icmp->icmp_cksum = in_cksum((unsigned short *)packet, 64);

    DBUGF(("Sending:\n"));
    DBUGF(("  icmp->icmp_type = %x\n", icmp->icmp_type));
    DBUGF(("  icmp->icmp_code = %x\n", ntohs(icmp->icmp_code)));
    DBUGF(("  icmp->icmp_cksum = %x\n", ntohs(icmp->icmp_cksum)));
    DBUGF(("  icmp->icmp_seq = %x\n", ntohs(icmp->icmp_seq)));
    DBUGF(("  icmp->icmp_id = %x\n", ntohs(icmp->icmp_id)));

    return sendto(socket_handle, packet, 64, 0,
                  (struct sockaddr *) addr, sizeOfAddr);
}

static int receivePacket(unsigned short      ident,
                         unsigned short      seq,
                         struct sockaddr_in *addr)
{
    char                packet[sizeof(struct icmp) + sizeof(struct ip)];
    struct icmp        *icmp;
    int                 result;
    int                 sizeOfAddr = sizeof(struct sockaddr_in);

    result = recvfrom(socket_handle, packet, 64, 0,
                      (struct sockaddr *) addr, &sizeOfAddr);

    if (result == -1)
    {
        if (errno == EWOULDBLOCK)
        {
            return 0;
        }

        DBUGF(("  result was -1, errno = %d\n", errno));

        return -1;
    }

    DBUGF(("result was %d\n", result));

    icmp = (struct icmp *) (packet + sizeof(struct ip));

    DBUGF(("Received was:\n"));
    DBUGF(("  icmp->icmp_type = %x\n", icmp->icmp_type));

    if (icmp->icmp_type != ICMP_ECHOREPLY)
    {
        DBUGF(("Don't care about non-ICMP_ECHOREPLY packets\n"));
    }
    else
    {
        DBUGF(("  icmp->icmp_code = %x\n", ntohs(icmp->icmp_code)));
        DBUGF(("  icmp->icmp_cksum = %x\n", ntohs(icmp->icmp_cksum)));
        DBUGF(("  icmp->icmp_seq = %x\n", ntohs(icmp->icmp_seq)));
        DBUGF(("  icmp->icmp_id = %x\n", ntohs(icmp->icmp_id)));

        /* FIXME alexms 20/01/2004: Should really sanity check the
         * checksum */

        if (ident == ntohs(icmp->icmp_id) &&
            seq == ntohs(icmp->icmp_seq))
        {
            DBUGF(("id/seq match!\n"));
            /* Match! */
            return 1;
        }

        DBUGF(("mismatch: id = %d/%d, seq = %d/%d\n", ident, icmp->icmp_id,
               seq, icmp->icmp_seq));
    }

    return 0;
}

int main(int argc, char *argv[])
{
    struct protoent   *proto;
    struct sockaddr_in addr;
    int                i;
    char              *host;
    dns_t             *dns;
    dns_status_t       dnsStatus;
    int                result;
    int                ioctlFlag;
    unsigned short     ident        = _swi(OS_ReadMonotonicTime, _RETURN(0));
    unsigned short     ntransmitted = -1;
    unsigned int       maxTime      = 100;
    int                startTime    = _swi(OS_ReadMonotonicTime, _RETURN(0));
    int                nextTime     = startTime;
    int                currentTime;

    atexit(exitFn);

    if (argc == 1)
    {
        syntaxExit(argv);
    }

    for (i = 1; i < argc; i++)
    {
        if (argv[i][0] != '-')
        {
            break;
        }

        if (strcmp(argv[i], "-h") == 0)
        {
            fprintf(stdout, "ifping version 1.19b (%s, %s) ", __DATE__,
                    __TIME__);
            fprintf(stdout, " Archifishal Software, 2005\n\n");
            syntaxDisplay(argv);
            fprintf(stdout, "\nifping takes the following options:\n");
            fprintf(stdout, "    -h             displays this help text\n");
            fprintf(stdout, "    -v             verbose output\n");
            fprintf(stdout, "    -t <time>      time in seconds to wait "
                            "before giving up on getting a reply\n");
            fprintf(stdout, "    <host>         the host to ping\n");
            fprintf(stdout, "    then <command> the command to execute in "
                            "the event of a successful ping\n");
            fprintf(stdout, "    else <command> the command to execute in "
                            "the event of a unsuccessful ping\n");
            exit(EXIT_SUCCESS);
        }
        else if (strcmp(argv[i], "-v") == 0)
        {
            verbose = true;
        }
        else if (strcmp(argv[i], "-t") == 0)
        {
            i++;

            if (i < argc)
            {
                maxTime = atoi(argv[i]) * 100;
            }
        }
    }

    if (i + 2 >= argc)
    {
        syntaxExit(argv);
    }

    if (strcasecmp(argv[i + 1], "Then") != 0)
    {
        syntaxExit(argv);
    }

    /* i now points at the host to ping */
    host = argv[i];

    i++;
    /* i now points at the command to execute */

    /* Now resolve the hostname... */
    dns = dns_gethostbyname(host);
    if (dns == NULL)
    {
        fprintf(stderr, "%s: dns lookup creation failed\n", argv[0]);
        exit(EXIT_FAILURE);
    }

    memset(&addr, 0, sizeof(addr));
    addr.sin_family=AF_INET;

    do
    {
        dnsStatus = dns_check(dns);

        DBUGF(("Checking dns resolve status... %d\n", dnsStatus));
    } while (dnsStatus != dns_complete_success &&
             dnsStatus != dns_complete_failure);

    switch (dnsStatus)
    {
        case dns_complete_failure:
            dns_dispose(dns);
            launchCommand(argc, argv, i, 0);
            exit(EXIT_SUCCESS);

        case dns_complete_success:
        {
            struct hostent *ent;

            ent = dns_getanswer(dns);

            addr.sin_addr.s_addr = *((unsigned int *)(void *)(*ent).h_addr);

            DBUGF(("resolved host '%s' ok! (%s)\n", host,
                   inet_ntoa(addr.sin_addr)));

            dns_dispose(dns);
            break;
        }
    }

    proto = getprotobyname("icmp");
    if (proto == NULL)
    {
      fprintf(stderr, "%s: unknown protocol icmp\n", argv[0]);
      exit(EXIT_FAILURE);
    }

    socket_handle = socket(AF_INET, SOCK_RAW, proto->p_proto);

    DBUGF(("created icmp socket %d\n", socket_handle));

    if (socket_handle < 0)
    {
        fprintf(stderr, "%s: socket error %d\n", argv[0], errno);
        exit(EXIT_FAILURE);
    }

    ioctlFlag = 1;
    result = socketioctl(socket_handle, FIONBIO, &ioctlFlag);
    if (result == -1)
    {
        fprintf(stderr, "%s: ioctl error (FIONBIO) %d\n", argv[0], errno);
        exit(EXIT_FAILURE);
    }

    ioctlFlag = 1;
    result = socketioctl(socket_handle, FIOSLEEPTW, &ioctlFlag);
    if (result == -1)
    {
        fprintf(stderr, "%s: ioctl error (FIOSLEEPTW) %d\n", argv[0], errno);
        exit(EXIT_FAILURE);
    }

    DBUGF(("Entering send/receive loop...\n"));

    do
    {
        result = 0;

        if (nextTime <= _swi(OS_ReadMonotonicTime, _RETURN(0)))
        {
            result = sendPacket(ident, ++ntransmitted, &addr);

            if (result == -1)
            {
                DBUGF(("sendto failed, errno was %d\n", errno));
            }
            else if (result != 64)
            {
                DBUGF(("sendto only sent %d bytes\n", result));
                result = -1;
            }
            else
            {
                DBUGF(("sent %d bytes\n", result));
            }

            nextTime = _swi(OS_ReadMonotonicTime, _RETURN(0)) + 100;
        }

        if (result != -1)
        {
            result = receivePacket(ident, ntransmitted, &addr);
        }

        currentTime = _swi(OS_ReadMonotonicTime, _RETURN(0));
    } while ((currentTime - startTime < maxTime) && result == 0);

    DBUGF(("ifping: closing icmp socket %d\n", socket_handle));

    socketclose(socket_handle);
    socket_handle = -1;

    switch (result)
    {
        case -1:
            DBUGF(("recvfrom failed, error was %d\n", errno));
            result = 0;
            break;

        case 0:
            DBUGF(("ifping timed out waiting for a response\n"));
            break;

        case 1:
            DBUGF(("ifping got a successful packet!\n"));
            break;
    }

    launchCommand(argc, argv, i, result);
}
