/*
** Released under GPL version 2
*/

/* TODO:
** -M and -E should also write the 0x55AA boot sector signature
*/

#define _LARGEFILE64_SOURCE

#include <sys/types.h>
#include <sys/stat.h>
#include <fcntl.h>
#include <errno.h>
#include <stdarg.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <unistd.h>

/* Boot record geometry and the partition type IDs treated as extended. */
enum {
	MBR_SIZE = 0x200,		/* one 512-byte sector */
	MAX_MBR_CODE_SIZE = 0x1b6,	/* boot-code bytes before the partition-table area */
	PART_DOS_EXTD = 0x05,
	PART_WIN_EXTD_LBA = 0x0f,
	PART_LINUX_EXTD = 0x85,
};

/*
 * Write "mbinstall: " and the printf-style message s/p to stderr.
 * stdout is flushed first so regular output and diagnostics stay in order.
 * No newline is appended; callers add their own terminator.
 */
void verror_msg(const char *s, va_list p)
{
	static const char prefix[] = "mbinstall: ";

	fflush(stdout);
	fputs(prefix, stderr);
	vfprintf(stderr, s, p);
}

/*
 * Print a formatted diagnostic (with the "mbinstall: " prefix added by
 * verror_msg) followed by a newline, then exit with status 1.
 */
void error_msg_and_die(const char *s, ...)
{
	va_list args;

	va_start(args, s);
	verror_msg(s, args);
	va_end(args);
	fputc('\n', stderr);
	exit(1);
}

/*
 * Print "mbinstall: <message>: <strerror(errno)>\n" to stderr.
 * When s is NULL or empty, only the prefix and the errno text are printed
 * (no ": " separator).
 */
void vperror_msg(const char *s, va_list p)
{
	/* capture errno before any stdio call below can change it */
	int saved_errno = errno;
	const char *sep;

	if (s == 0)
		s = "";
	verror_msg(s, p);
	sep = (*s != '\0') ? ": " : s;
	fprintf(stderr, "%s%s\n", sep, strerror(saved_errno));
}

/*
 * Print a formatted diagnostic followed by ": " and the current errno text
 * (via vperror_msg), then exit with status 1.
 */
void perror_msg_and_die(const char *s, ...)
{
	va_list args;

	va_start(args, s);
	vperror_msg(s, args);
	va_end(args);
	exit(1);
}

/*
 * open(2) that cannot fail: returns the new descriptor, or terminates with
 * "mbinstall: <pathname>: <strerror>" when open() returns -1.
 */
int xopen(const char *pathname, int flags)
{
	const int fd = open(pathname, flags);

	if (fd == -1)
		perror_msg_and_die("%s", pathname);
	return fd;
}

/*
 * read(2) that transparently restarts when interrupted by a signal (EINTR).
 * Returns whatever read() finally returns: bytes read, 0 at EOF, or -1 with
 * errno set for any other error.
 */
ssize_t safe_read(int fd, void *buf, size_t count)
{
	for (;;) {
		ssize_t got = read(fd, buf, count);

		if (got >= 0 || errno != EINTR)
			return got;
	}
}

/*
 * Read until len bytes have been stored in buf or EOF is reached.
 * Returns the number of bytes read (less than len only at EOF), or the
 * negative value from safe_read() on error.  On error, data already read
 * is discarded from the count.
 */
ssize_t full_read(int fd, void *buf, size_t len)
{
	char *dst = buf;
	ssize_t done = 0;

	while (len != 0) {
		ssize_t got = safe_read(fd, dst, len);

		if (got < 0)
			return got;	/* propagate read()'s -1 */
		if (got == 0)
			break;		/* EOF: report what we have */
		dst += got;
		done += got;
		len -= (size_t)got;
	}
	return done;
}

/*
 * Read exactly count bytes from fd into buf, or terminate the program.
 *
 * A read() error is reported together with strerror(errno).  Reaching EOF
 * before count bytes is reported as a short read.  This is used for the MBR
 * and for every extended boot record, so the message must not claim that a
 * particular sector (e.g. the first one) was involved.
 */
void xfull_read(int fd, void *buf, size_t count)
{
	ssize_t size = full_read(fd, buf, count);

	if (size < 0)
		perror_msg_and_die("read error");
	/* size is non-negative here, so the unsigned comparison is exact */
	if ((size_t)size != count)
		error_msg_and_die("can't read sector: unexpected end of device");
}

/*
 * write(2) that transparently restarts when interrupted by a signal (EINTR).
 * Returns whatever write() finally returns: bytes written, or -1 with errno
 * set for any other error.
 */
ssize_t safe_write(int fd, const void *buf, size_t count)
{
	for (;;) {
		ssize_t put = write(fd, buf, count);

		if (put >= 0 || errno != EINTR)
			return put;
	}
}

/*
 * Write all len bytes of buf to fd, retrying after partial writes.
 *
 * Returns the number of bytes written, or the negative value from
 * safe_write() on error.  If write() makes no progress (returns 0 for a
 * non-zero request, which POSIX permits), the loop stops and the short
 * count is returned instead of spinning forever.  This mirrors full_read()'s
 * handling of a 0 return; callers detect the shortfall by comparing with len.
 */
ssize_t full_write(int fd, const void *buf, size_t len)
{
	ssize_t cc;
	ssize_t total;

	total = 0;
	while (len > 0) {
		cc = safe_write(fd, buf, len);
		if (cc < 0)
			return cc;	/* write() returns -1 on failure. */
		if (cc == 0)
			break;		/* no progress: report the short count */
		total += cc;
		buf = ((const char *)buf) + cc;
		len -= cc;
	}
	return total;
}

/*
 * Write all len bytes of buf to fd, or terminate with
 * "mbinstall: write error: <strerror>" if fewer bytes (or an error) result.
 */
void xfull_write(int fd, const void *buf, size_t len)
{
	if (full_write(fd, buf, len) != len)
		perror_msg_and_die("write error");
}

/*
 * close(2) that terminates the program on failure.  Checking close() matters
 * for written descriptors: delayed I/O errors may only be reported here.
 */
void xclose(int fd)
{
	if (close(fd) != 0)
		perror_msg_and_die("error closing file (I/O error?)");
}


/* Unsigned integer names used by the on-disk structures below; they are
 * expected to be 8/16/32 bits wide (main() verifies the resulting struct
 * sizes). */
typedef unsigned char		uint8;
typedef unsigned short int	uint16;
typedef unsigned int		uint32;


/* A 13-byte boot-code signature.  It is NOT NUL-terminated (the literals
 * below fill the array exactly), so compare it with memcmp(). */
typedef struct sig_s {
	char str[13];
} sig;

/* Signatures expected in the MBR and in each extended boot record. */
const sig mbr_signature = { .str = "MultiBoot 0.4" };
const sig ext_signature = { .str = "\r\nExtBoot 0.4" };

/*
 * One 16-byte entry of a DOS-style partition table.  The struct is packed
 * so that it matches the on-disk layout; field order and sizes must not
 * change.
 */
typedef struct part_s {
	uint8	bootable;	// +0: 0x80/0x00 - bootable/not bootable
	uint8	start_head;	// +1: head (start)
	uint16	start_cyl_sec;	// +2: cyl+sect (start)
	uint8	type;		// +4: partition type ID (see type_is_ext())
	uint8	end_head;	// +5: head (end)
	uint16	end_cyl_sec;	// +6: cyl+sec (end)
	uint32	sector_ofs;	// +8: offset in sectors (used as absolute for MBR entries; walk_ext() adds the extended partition base for the EBR link entry)
	uint32	sector_size;	// +12: size in sectors
} __attribute((packed)) part;

/*
 * On-disk layout of the 512-byte MultiBoot MBR.  The struct is packed, so
 * the hex offsets in the comments are on-disk offsets; main() checks that
 * sizeof(mbr) == MBR_SIZE.  -M overwrites only the first MAX_MBR_CODE_SIZE
 * bytes (jmp..rest_of_code), leaving vol_no and the partition table intact.
 */
typedef struct mbr_s {
	uint8	jmp[3];		/* 0x00: start of boot code (presumably a jump over the fields below -- TODO confirm) */
	sig	signature;	/* 0x03: "MultiBoot 0.4", checked by print_mbr() */
	uint8	crlf[3];	/* 0x10 */
	uint8	default_char;	/* 0x13: key assumed when the timeout is reached */
	uint8	chars[4];	/* 0x14: menu key per primary partition, 0 = not selectable */
	uint16	delay;		/* 0x18: menu timeout in units of 55 ms */
	uint16	offsets[4];	/* 0x1a..0x20: label location per primary partition, stored biased by 0x600 */
	uint8	rest_of_code[MAX_MBR_CODE_SIZE-0x22];	/* 0x22: rest of the boot code; valid labels live in here */
	uint16	pad1;		/* 0x1b6 */
	uint32	vol_no;		/* 0x1b8: presumably the disk signature / volume number -- TODO confirm */
	uint16	pad2;		/* 0x1bc */
	part	part[4];	/* 0x1be: primary partition table */
	uint16	bootsig;	/* 0x1fe: presumably the 0x55AA boot signature (not written by this tool) */
} __attribute((packed)) mbr;

/*
 * On-disk layout of a 512-byte extended boot record (EBR) carrying the
 * ExtBoot code.  Packed; the hex offsets in the comments are on-disk
 * offsets, and main() checks that sizeof(ext) == MBR_SIZE.  Unlike the MBR,
 * each EBR holds a single menu key and label for its own logical partition.
 */
typedef struct ext_s {
	uint8	jmp[3];		/* 0x00: start of boot code */
	sig	signature;	/* 0x03: "\r\nExtBoot 0.4", checked by print_ext() */
	uint8	crlf[3];	/* 0x10 */
	uint8	default_char;	/* 0x13: ext-menu default key (main() only sets it in the first EBR) */
	uint16	delay;		/* 0x14: menu timeout in units of 55 ms */
	uint8	chars;		/* 0x16: menu key for this logical partition, 0 = not selectable */
	uint8	pad0;		/* 0x17 */
	uint16	offset;		/* 0x18: label location, stored biased by 0x600 */
	uint8	rest_of_code[MAX_MBR_CODE_SIZE-0x1a];	/* 0x1a: rest of the boot code; the label lives in here */
	uint16	pad1;		/* 0x1b6 */
	uint32	vol_no;		/* 0x1b8 */
	uint16	pad2;		/* 0x1bc */
	part	part[2];	/* 0x1be: [0] presumably this logical partition; [1] link to the next EBR, followed by walk_ext() */
	part	unused[2];	/* 0x1de */
	uint16	bootsig;	/* 0x1fe */
} __attribute((packed)) ext;

/* Set to 1 by print_mbr() / print_ext() when a boot record lacks the
 * expected signature; main() then skips editing that kind of record
 * (extended records are still rewritten with -E). */
int bad_sig_mbr, bad_sig_ext;

/*
 * Print the MultiBoot menu configuration found in an MBR image.
 *
 * If the "MultiBoot 0.4" signature is missing, sets bad_sig_mbr, says so and
 * returns.  Otherwise it prints one line per selectable primary partition
 * (chars[i] != 0) with its key and label, then the default choice and
 * timeout (delay is in 55 ms units).
 *
 * offsets[i] holds the label location biased by 0x600.  A label is used only
 * if it starts after byte 0x80, lies inside the boot-code area and is
 * NUL-terminated before MAX_MBR_CODE_SIZE; otherwise a warning is printed
 * and "BAD" is shown instead.  The checks use integer offsets and a bounded
 * memchr(), so a corrupt offset can neither form a pointer outside the
 * sector nor make strlen() run past its end.
 */
void print_mbr(const mbr *mbr_ptr)
{
	const char *base = (const char *)mbr_ptr;
	int i;

	if(memcmp(&mbr_ptr->signature, &mbr_signature, sizeof(sig))) {
		bad_sig_mbr = 1;
		puts("signature 'MultiBoot 0.4' not found in MBR");
		return;
	}
	for(i=0; i<4; i++) {
		long ofs = (long)mbr_ptr->offsets[i] - 0x600;
		const char *label = "BAD";

		if(ofs > 0x80 && ofs < MAX_MBR_CODE_SIZE
		&& memchr(base + ofs, '\0', (size_t)(MAX_MBR_CODE_SIZE - ofs)))
			label = base + ofs;
		else
			printf("Warning: label %d is bad!\n", i);
		if(mbr_ptr->chars[i]) {
			printf("Partition %d: %c - %s\n", i+1,
				mbr_ptr->chars[i],
				label);
		}
	}
	for(i=0; i<4; i++) {
		if(mbr_ptr->default_char == mbr_ptr->chars[i]) {
			printf("Default choice is '%c', timeout is %d msec\n",
				mbr_ptr->default_char, mbr_ptr->delay*55);
			return;
		}
	}
	/* The default key matches no partition, so the timeout has nothing to
	 * pick: with a delay the menu just waits, with delay 0 there is no menu
	 * but a valid key must still be pressed ("blind boot"). */
	if(mbr_ptr->delay)
		puts("No default choice, no boot timeout");
	else
		puts("Blind boot selection enabled");
}

/*
 * Print the menu entry stored in one extended boot record.
 *
 * depth is the 1-based partition number used in messages (5 for the first
 * logical partition).  If the ExtBoot signature is missing, sets bad_sig_ext,
 * says so and returns.  Otherwise prints the entry's key and label when the
 * key is non-zero.
 *
 * The label location is stored biased by 0x600.  It is accepted only if it
 * starts after byte 0x80, lies inside the boot-code area and is
 * NUL-terminated before MAX_MBR_CODE_SIZE; otherwise a warning is printed
 * and "BAD" is shown.  A bounded memchr() keeps a corrupt offset from
 * reading outside the sector.
 */
void print_ext(const ext *ext_ptr, int depth)
{
	const char *base = (const char *)ext_ptr;
	const char *label = "BAD";
	long ofs;

	if(memcmp(&ext_ptr->signature, &ext_signature, sizeof(sig))) {
		bad_sig_ext = 1;
		printf("signature 'ExtBoot 0.4' not found"
		    " in ext partition #%d\n", depth);
		return;
	}
	ofs = (long)ext_ptr->offset - 0x600;
	if(ofs > 0x80 && ofs < MAX_MBR_CODE_SIZE
	&& memchr(base + ofs, '\0', (size_t)(MAX_MBR_CODE_SIZE - ofs)))
		label = base + ofs;
	else
		printf("Warning: label is bad!\n");
	if(ext_ptr->chars) {
		printf("Partition %d: %c - %s\n", depth,
			ext_ptr->chars, label);
	}
}

/* Return 1 if the partition type ID is one of the extended-partition
 * container types handled here, 0 otherwise. */
int type_is_ext(uint8 type)
{
	return type == PART_DOS_EXTD
	    || type == PART_WIN_EXTD_LBA
	    || type == PART_LINUX_EXTD;
}

/* Read the 512-byte boot record at absolute sector number `sector` into
 * *storage, terminating the program if the seek or the read fails. */
void xread_ext(ext* storage, int fd, uint32 sector)
{
	const off64_t want = (off64_t)sector * MBR_SIZE;

	if (lseek64(fd, want, SEEK_SET) != want)
		error_msg_and_die("can't seek to %llu", (unsigned long long)want);
	xfull_read(fd, storage, MBR_SIZE);
}

/* Write *storage as the 512-byte boot record at absolute sector number
 * `sector`, terminating the program if the seek or the write fails. */
void xwrite_ext(ext* storage, int fd, uint32 sector)
{
	const off64_t want = (off64_t)sector * MBR_SIZE;

	if (lseek64(fd, want, SEEK_SET) != want)
		error_msg_and_die("can't seek to %llu", (unsigned long long)want);
	xfull_write(fd, storage, MBR_SIZE);
}

/*
 * Read the chain of extended boot records (EBRs) belonging to the first
 * extended partition listed in the MBR's primary table.  Any further
 * extended partitions in the MBR are ignored.
 *
 * storage:  array receiving the EBRs in chain order
 * ext_ofs:  array receiving each EBR's absolute sector number
 * maxcount: capacity of both arrays; a chain longer than this is treated
 *           as a loop and is fatal
 * fd:       descriptor of the disk/image to read from
 * mbr_ptr:  MBR whose partition table is searched
 *
 * Each EBR is printed via print_ext() as partition 5, 6, ...
 * Returns the number of EBRs read (0 if maxcount <= 0 or there is no
 * extended partition).
 */
int walk_ext(ext* storage, uint32* ext_ofs, int maxcount, int fd, const mbr *mbr_ptr)
{
	uint32 ext_base;
	int i, count;

	if(maxcount <= 0)
		return 0;

	/* find the first extended partition in the primary table */
	for(i=0; i<4; i++)
		if(type_is_ext(mbr_ptr->part[i].type))
			break;
	if(i == 4)
		return 0;

	ext_base = mbr_ptr->part[i].sector_ofs;
	ext_ofs[0] = ext_base;
	xread_ext(storage, fd, ext_base);
	print_ext(storage, 5);
	count = 1;
	maxcount--;

	/* part[1] of each EBR links to the next one, relative to ext_base */
	while(type_is_ext(storage->part[1].type)) {
		uint32 base = ext_base + storage->part[1].sector_ofs;

		if(!maxcount)
			error_msg_and_die("extended partition loop? aborting");
		ext_ofs[count] = base;
		xread_ext(++storage, fd, base);
		print_ext(storage, count+5);
		count++;
		maxcount--;
	}
	return count;
}


/*
 * Print command-line help and exit with status 1.
 * The -L syntax shown here matches what main() parses: "NN:c[:label]".
 */
void usage(void)
{
	puts(
	"configures MultiBoot 0.4 Master Boot Record\n"
	"(a MBR with ability to interactively select which partition to boot)\n"
	"\n"
	"Usage:\n"
	"<program> [-t timeout_ms] [-d default_key] [-e default_key_for_ext_partition]\n"
	"				[-1..9 c[:label]] [-L NN:c[:label]]\n"
	"				[-M] [-E] block_device\n"
	"	-t timeout_ms	Timeout for boot menu\n"
	"	-d default_key	Key which is assumed if timeout is reached\n"
	"	-e default_key	Same for extended partition boot menu\n"
	"	-N c[:label]	key 'c' selects N'th partition (1..4 primary, 5+ extd)\n"
	"			label: on-screen label for this partition\n"
	"	-L NN:c[:lbl]	Same as -N... for N>9\n"
	"	-M, -E		Install boot code to MBR and/or extended partition\n"
	"			(normally I would just check for their presence)\n"
	"\n"
	"	Special cases:\n"
	"	-N -		makes partition N unselectable\n"
	"	-d <bad_char>	disables timeout\n"
	"	-t 0 -d c	no menu; boot partition 'c'\n"
	"	-t 0 -d <bad>	blind boot: no menu, but you still have to press\n"
	"			a character corresponding to some partition"
	);
	exit(1);
}

/* Boot-code images installed by -M (MBR) and -E (extended boot records).
 * main() copies each over the first MAX_MBR_CODE_SIZE bytes of the record
 * and checks that both arrays are exactly that size.  The definitions
 * presumably come from mboot_mbr.h / mboot_ext.h, included just below. */
extern uint8 mboot_mbr[];
extern uint8 mboot_ext[];

#include "mboot_mbr.h"
#include "mboot_ext.h"

/*
 * Return the label that a boot record's stored label location refers to,
 * or NULL if that location is unusable.
 *
 * stored_ofs is the location as stored in the record (biased by 0x600).  The
 * label must start after byte 0x80, lie inside the boot-code area and be
 * NUL-terminated before MAX_MBR_CODE_SIZE.  These are the same conditions
 * print_mbr()/print_ext() use to decide whether to print "BAD".
 */
static const char *sector_label(const void *sector, unsigned stored_ofs)
{
	const char *base = sector;
	long ofs = (long)stored_ofs - 0x600;

	if (ofs <= 0x80 || ofs >= MAX_MBR_CODE_SIZE)
		return NULL;
	if (!memchr(base + ofs, '\0', (size_t)(MAX_MBR_CODE_SIZE - ofs)))
		return NULL;
	return base + ofs;
}

/*
 * Record a "-N c[:label]" request for 0-based partition index idx.
 * An argument starting with '-' makes the partition unselectable (key '\0',
 * empty label).  Otherwise the first character is the key, and text after
 * "c:" becomes the new on-screen label.
 */
static void set_partition_key(int *chars, const char **labels, int idx,
		const char *arg)
{
	if (arg[0] == '-') {
		chars[idx] = '\0';
		labels[idx] = "";
		return;
	}
	/* unsigned char: keys >= 0x80 must not turn negative ("keep") */
	chars[idx] = (unsigned char)arg[0];
	/* arg may be empty; never look past its terminator */
	if (arg[0] != '\0' && arg[1] == ':')
		labels[idx] = arg + 2;
}

/*
 * Convert a "-t" argument in milliseconds to the 55 ms ticks stored in the
 * 16-bit delay field, rounding to nearest.  Dies on input that is not a
 * plain non-negative number, or that does not fit the field.  Without this
 * check, a typo would silently become timeout 0 (blind boot).
 */
static int parse_timeout(const char *arg)
{
	const long max_ms = 0xFFFFL * 55;
	char *end;
	long ms = strtol(arg, &end, 10);

	/* out-of-range strtol results (LONG_MIN/LONG_MAX) fail the range test */
	if (end == arg || *end != '\0' || ms < 0 || ms > max_ms)
		error_msg_and_die("timeout must be a number of msec in 0..%ld",
			max_ms);
	return (int)((ms + 55/2) / 55);
}

/*
 * Show the MultiBoot configuration of the MBR and extended boot records on
 * argv's block device, then apply the requested changes (see usage()).
 * Records are rewritten only if something actually changed, or if -M/-E
 * asked for the boot code to be reinstalled.
 */
int main(int argc, char* argv[])
{
	mbr mbr_buf, savebuf;
	ext ext_buf[60];
	uint32 ext_ofs[60];
	int new_chars[64];		/* key per partition index; -1: keep */
	const char* new_labels[64];	/* label per partition index; NULL: keep */
	int new_delay = -1;
	int new_default_mbr = -1;
	int new_default_ext = -1;
	int replace_mbr = 0;
	int replace_ext = 0;
	int fw = -1;
	int fd;
	int i,sz,ext_cnt,dirty,cur_label_ofs;

	/* Layout checks: a call to BUG() only survives (and presumably then
	 * fails to link) if one of these sizes is wrong. */
	extern void BUG(void);
	if(sizeof(mbr) != MBR_SIZE) BUG();
	if(sizeof(ext) != MBR_SIZE) BUG();
	if(sizeof(mboot_mbr) != MAX_MBR_CODE_SIZE) BUG();
	if(sizeof(mboot_ext) != MAX_MBR_CODE_SIZE) BUG();

	for(i=0; i<64; i++) {
		new_chars[i] = -1;
		new_labels[i] = NULL;
	}

	while(1) {
		int c = getopt(argc, argv, "MEt:d:e:1:2:3:4:5:6:7:8:9:L:");
		if (c == -1)
			break;

		switch (c) {
		case 'M':
			replace_mbr = 1; break;
		case 'E':
			replace_ext = 1; break;
		case 'L': {
			/* "-L NN:c[:label]" is "-NN c[:label]" for NN > 9;
			 * partitions are numbered from 1, like -1..-9 */
			char *end;
			long n = strtol(optarg, &end, 10);

			if(end == optarg || *end != ':')
				error_msg_and_die("malformed -L option");
			if(n < 1 || n > 64)
				error_msg_and_die("partition# must be in 1..64");
			set_partition_key(new_chars, new_labels, (int)n - 1, end + 1);
			break;
		}
		case '1':case '2':case '3':case '4':case '5':
		case '6':case '7':case '8':case '9':
			set_partition_key(new_chars, new_labels, c - '1', optarg);
			break;
		case 'd':
			new_default_mbr = (unsigned char)optarg[0];
			break;
		case 'e':
			new_default_ext = (unsigned char)optarg[0];
			break;
		case 't':
			new_delay = parse_timeout(optarg);
			break;
		default:
			usage();
			break;
		}
	}
	/* shift so that argv[1] is the first non-option argument */
	optind--;
	argc -= optind;
	argv += optind;
	if(argc != 2)
		usage();

	fd = xopen(argv[1], O_RDONLY|O_LARGEFILE);
	xfull_read(fd, &mbr_buf, MBR_SIZE);
	if(replace_mbr) {
		/* boot code only; vol_no and the partition table are kept */
		memcpy(&mbr_buf, mboot_mbr, MAX_MBR_CODE_SIZE);
	}
	print_mbr(&mbr_buf);
	savebuf = mbr_buf;
	ext_cnt = walk_ext(ext_buf, ext_ofs, 60, fd, &mbr_buf);
	if(ext_cnt && !bad_sig_ext)
		printf("Default choice for ext menu is '%c'\n",
			ext_buf[0].default_char);

	putchar('\n');

	if(bad_sig_mbr) {
		puts("No editing is done on MBR");
		goto ext;
	}

	/* Labels are repacked back to back starting where label 1 is now.
	 * Kept labels are read from savebuf, so repacking into mbr_buf
	 * cannot clobber a label that has not been copied yet. */
	cur_label_ofs = savebuf.offsets[0]-0x600;
	if(cur_label_ofs <= 0x80 || cur_label_ofs >= MAX_MBR_CODE_SIZE)
		error_msg_and_die("label area offset in MBR is bad. aborting");
	sz = 0;
	for(i=0; i<4; i++) {
		if(new_chars[i] >= 0)
			mbr_buf.chars[i] = new_chars[i];
		if(!new_labels[i]) {
			new_labels[i] = sector_label(&savebuf, savebuf.offsets[i]);
			if(!new_labels[i])
				error_msg_and_die("label of partition %d is bad;"
					" set a new one with -%d c:label",
					i+1, i+1);
		}
		sz += strlen(new_labels[i])+1;
		if(cur_label_ofs + sz >= MAX_MBR_CODE_SIZE)
			error_msg_and_die("sum of label sizes exceeds %d."
				" aborting",
				MAX_MBR_CODE_SIZE);
	}
	if(new_default_mbr >= 0)
		mbr_buf.default_char = new_default_mbr;
	if(new_delay >= 0)
		mbr_buf.delay = new_delay;
	for(i=0; i<4; i++) {
		sz = strlen(new_labels[i])+1;
		mbr_buf.offsets[i] = cur_label_ofs+0x600;
		memcpy((char*)&mbr_buf + cur_label_ofs, new_labels[i], sz);
		cur_label_ofs += sz;
	}

	if(replace_mbr || memcmp(&mbr_buf, &savebuf, sizeof(mbr))) {
		puts("Rewriting MBR:");
		print_mbr(&mbr_buf);
		fw = xopen(argv[1], O_WRONLY|O_LARGEFILE);
		xfull_write(fw, &mbr_buf, MBR_SIZE);
		puts("MBR is modified");
	}

ext:
	if(!replace_ext && bad_sig_ext) {
		puts("No editing is done on extended partitions");
		goto end;
	}

	/* The menu-wide settings (default key, delay) live in the first EBR. */
	dirty = 0;
	if(replace_ext) {
		memcpy(&ext_buf[0], mboot_ext, MAX_MBR_CODE_SIZE);
		ext_buf[0].chars = 'A';
		dirty = 1;
	}
	if(new_default_ext >= 0) {
		ext_buf[0].default_char = new_default_ext;
		dirty = 1;
	}
	if(new_delay >= 0) {
		ext_buf[0].delay = new_delay;
		dirty = 1;
	}
	/* i is the 0-based partition index: 4 is the first logical partition */
	for(i=4; i < ext_cnt+4; i++) {
		ext* ext_ptr = &ext_buf[i-4];
		int offset;

		if(i != 4) {
			dirty = 0;
			if(replace_ext) {
				memcpy(ext_ptr, mboot_ext, MAX_MBR_CODE_SIZE);
				ext_ptr->chars = i+('A'-4);
				dirty = 1;
			}
		}
		offset = ext_ptr->offset - 0x600;

		if(new_chars[i] >= 0) {
			ext_ptr->chars = new_chars[i];
			dirty = 1;
		}
		if(new_labels[i]) {
			/* never write a label outside the boot-code area */
			if(offset <= 0x80 || offset >= MAX_MBR_CODE_SIZE)
				error_msg_and_die("label offset in ext #%d is bad."
					" aborting", i+1);
			if(strlen(new_labels[i])+1 >=
			   (size_t)(MAX_MBR_CODE_SIZE - offset))
				error_msg_and_die("sum of label sizes exceeds %d."
					" aborting",
					MAX_MBR_CODE_SIZE);
			strcpy((char*)ext_ptr + offset, new_labels[i]);
			dirty = 1;
		}
		if(dirty) {
			if(fw < 0)
				fw = xopen(argv[1], O_WRONLY|O_LARGEFILE);
			printf("Rewriting ext #%d:\n", i+1);
			print_ext(ext_ptr, i+1);
			xwrite_ext(ext_ptr, fw, ext_ofs[i-4]);
			if(i == 4)
				printf("Default choice for ext menu is '%c'\n",
					ext_ptr->default_char);
			printf("ext #%d is modified\n", i+1);
		}
	}

end:
	/* We want to get an error message if there was an I/O problem */
	if(fw >= 0)
		xclose(fw);
	return 0;
}
