#include <errno.h>
#include <limits.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>

/*
 * RC4 cipher context: the 256-entry permutation table plus the two
 * running indices that cipher() advances for every keystream byte.
 */
typedef struct _RC4Con {
    int x;          /* first stream index, stepped by one per output byte */
    int y;          /* second stream index, advanced by state[x] */
    int state[256]; /* permutation of the values 0..255 */
} RC4Con;

/*
 * Prepare an RC4 context from a key.
 *
 *   rc4      - context to fill in; its previous contents are discarded
 *   key      - key bytes (need not be NUL-terminated)
 *   key_size - number of bytes in key; used as a modulus, so it must be
 *              non-zero
 *   loopN    - how many full 256-step mixing passes to run over the table
 *
 * The table starts as the identity permutation and both stream indices
 * are reset to zero.  The swap index is carried over from one pass to
 * the next rather than being reset between passes.
 */
static void
init(RC4Con *rc4, const char *key, int key_size, int loopN)
{
    int pass, pos, swap_pos, held;

    for (pos = 0; pos < 256; pos++)
        rc4->state[pos] = pos;

    rc4->x = 0;
    rc4->y = 0;

    swap_pos = 0;
    for (pass = 0; pass < loopN; pass++) {
        for (pos = 0; pos < 256; pos++) {
            held = rc4->state[pos];
            /* The key is reused cyclically when shorter than the table. */
            swap_pos = (swap_pos + held + key[pos % key_size]) & 255;
            rc4->state[pos] = rc4->state[swap_pos];
            rc4->state[swap_pos] = held;
        }
    }
}

/*
 * Produce the next keystream byte and advance the context.
 *
 * Steps x by one, adds state[x] to y (both modulo 256), swaps the two
 * table entries they select, and returns the table entry indexed by the
 * sum of those two entries modulo 256.  The caller XORs the returned
 * value (0..255) with a data byte.
 */
int
cipher(RC4Con *rc4)
{
    int i = (rc4->x + 1) & 255;
    int at_i = rc4->state[i];
    int j = (rc4->y + at_i) & 255;
    int at_j = rc4->state[j];

    /* Exchange the two selected entries. */
    rc4->state[i] = at_j;
    rc4->state[j] = at_i;

    rc4->x = i;
    rc4->y = j;

    return rc4->state[(at_i + at_j) & 255];
}

/*
 * Print the command-line synopsis to stderr and terminate the process
 * with exit status 1.
 *
 * The function never returns.  It is declared to return int only so
 * that main() can write `return usage();`.
 */
int
usage(void)
{
    /* The synopsis lists -l as well, since main() accepts it. */
    fprintf(stderr,
            "USAGE: cyphersaber [-d] [-l loops] -i infile -o outfile\n");
    exit(1);
}

/* Buffer sizes used by main(): passphrase line buffer and per-file IV. */
enum { KEY_LINE_MAX = 40, IV_LEN = 10 };

/*
 * Zero a buffer through a volatile pointer so the compiler cannot drop
 * the stores as dead writes (a plain memset before the object goes out
 * of scope may legally be optimised away).
 */
static void
wipe_secret(void *buf, size_t len)
{
    volatile unsigned char *p = buf;

    while (len-- > 0)
        *p++ = 0;
}

/*
 * Encrypt or decrypt a file with a passphrase read from stdin.
 *
 * Options:
 *   -i infile   file to read (required)
 *   -o outfile  file to write (required, truncated on open)
 *   -d          decrypt instead of encrypt
 *   -l loops    number of key-schedule passes handed to init(); must be
 *               a positive integer, defaults to 1
 *
 * When encrypting, IV_LEN bytes are read from /dev/random and written at
 * the start of the output; when decrypting, the first IV_LEN bytes of the
 * input are taken as that value.  The cipher key is the passphrase
 * followed by those bytes.
 *
 * Returns 0 on success and 1 on any error.  Bad or missing options end
 * the program through usage().
 */
int
main(int argc, char *argv[])
{
    int z, c;
    int decrypt = 0;
    int loop = 1;
    int status = 1;
    const char *infile = NULL;
    const char *outfile = NULL;
    FILE *in_fd = NULL;
    FILE *out_fd = NULL;
    RC4Con rc4;
    char key[KEY_LINE_MAX + IV_LEN];
    size_t key_len;
    size_t iv_bytes;

    for (z = 1; z < argc; z++) {
        if (!strcmp(argv[z], "-d")) {
            decrypt = 1;
        } else if (!strcmp(argv[z], "-i")) {
            /* A flag in last position has no value to consume. */
            if (z + 1 >= argc)
                return usage();
            infile = argv[++z];
        } else if (!strcmp(argv[z], "-o")) {
            if (z + 1 >= argc)
                return usage();
            outfile = argv[++z];
        } else if (!strcmp(argv[z], "-l")) {
            char *end;
            long passes;

            if (z + 1 >= argc)
                return usage();
            z++;
            errno = 0;
            passes = strtol(argv[z], &end, 10);
            /*
             * Zero or negative counts are rejected: init() would then
             * skip key mixing and the keystream would not depend on the
             * key at all.
             */
            if (end == argv[z] || *end != '\0' || errno == ERANGE ||
                passes < 1 || passes > INT_MAX) {
                fprintf(stderr, "Error: Invalid loop count %s\n", argv[z]);
                return 1;
            }
            loop = (int)passes;
        }
        /* Unrecognised arguments are ignored. */
    }
    if (infile == NULL || outfile == NULL)
        return usage();

    if ((in_fd = fopen(infile, "rb")) == NULL) {
        fprintf(stderr, "Error: Unable to open file %s\n", infile);
        goto done;
    }

    if ((out_fd = fopen(outfile, "wb")) == NULL) {
        fprintf(stderr, "Error: Unable to open file %s\n", outfile);
        goto done;
    }

    printf("Welcome to Scott's CipherSaber!\r\n");
    printf("Enter your key: ");
    fflush(stdout);
    if (fgets(key, KEY_LINE_MAX, stdin) == NULL) {
        fprintf(stderr, "Error reading key\n");
        goto done;
    }
    /*
     * Drop the trailing newline only if one was actually read; a line
     * that filled the buffer or ended at EOF keeps all its characters.
     */
    key_len = strcspn(key, "\n");

    /* The IV is placed directly after the passphrase bytes. */
    if (decrypt) {
        iv_bytes = fread(key + key_len, 1, IV_LEN, in_fd);
    } else {
        FILE *f_rand = fopen("/dev/random", "rb");

        if (f_rand == NULL) {
            fprintf(stderr, "Error reading from /dev/random\n");
            goto done;
        }
        iv_bytes = fread(key + key_len, 1, IV_LEN, f_rand);
        fclose(f_rand);
    }
    if (iv_bytes != (size_t)IV_LEN) {
        fprintf(stderr, "Error reading initialization vector\n");
        goto done;
    }
    /* Only emit the IV once it is known to be complete. */
    if (!decrypt &&
        fwrite(key + key_len, 1, IV_LEN, out_fd) != (size_t)IV_LEN) {
        fprintf(stderr, "Error writing to file %s\n", outfile);
        goto done;
    }

    init(&rc4, key, (int)key_len + IV_LEN, loop);
    wipe_secret(key, sizeof key);

    while ((c = fgetc(in_fd)) != EOF) {
        if (putc(c ^ cipher(&rc4), out_fd) == EOF) {
            fprintf(stderr, "Error writing to file %s\n", outfile);
            goto done;
        }
    }
    if (ferror(in_fd)) {
        fprintf(stderr, "Error reading from file %s\n", infile);
        goto done;
    }

    /* fclose flushes buffered output, so its result decides success. */
    if (fclose(out_fd) != 0) {
        out_fd = NULL;
        fprintf(stderr, "Error writing to file %s\n", outfile);
        goto done;
    }
    out_fd = NULL;
    status = 0;

done:
    wipe_secret(key, sizeof key);
    wipe_secret(&rc4, sizeof rc4);
    if (out_fd != NULL)
        fclose(out_fd);
    if (in_fd != NULL)
        fclose(in_fd);
    return status;
}