smash-MPI

A runtime library for injecting faults and delays into Open MPI.
git clone git@git.mpah.dev:smash-MPI.git
Log | Files | Refs | README | LICENSE

hooking.c (7549B)


      1 #include <dlfcn.h>
      2 #include <err.h>
      3 #include <gnu/lib-names.h>
      4 #include <mpi.h>
      5 #include <stdlib.h>
      6 #include <signal.h>
      7 #include <string.h>
      8 #include <unistd.h>
      9 #include <stdio.h>
     10 #include <time.h>
     11 
     12 #include "callout.h"
     13 #include "hooking.h"
     14 #include "parser.h"
     15 
/* MPI tag reserved for the internal graph-collection handshake that runs
 * during MPI_Finalize(); chosen to avoid colliding with application tags. */
#define SMASH_GRAPH 0x1234

/* POSIX timer driving the callout machinery (armed once in MPI_Init). */
timer_t smash_timer_id;
/* This process's rank in MPI_COMM_WORLD, cached by MPI_Init. */
unsigned int smash_my_rank;
/* smash_dead: set once this rank has entered its injected failure.
 * smash_world_size: size of MPI_COMM_WORLD, cached by MPI_Init.
 * smash_alarm: flag ensuring the SIGALRM timer is installed only once. */
int smash_dead, smash_world_size, smash_alarm;

/* Delay and failure scenarios parsed from the config files in
 * __libc_start_main(); freed in MPI_Finalize(). */
struct cfg_delays *smash_delays;
struct cfg_failures *smash_failures;

/* One observed point-to-point message edge: sender rank -> receiver rank. */
struct smash_graph_msg {
	int src, dst;
};

/* Fixed-capacity per-rank log of received messages, later merged by rank 0
 * into the communication graph written by save_graph(). */
struct smash_graph_msgs {
	size_t i; /* number of valid entries in msgs */
	struct smash_graph_msg msgs[4096];
} smash_graph_msgs;

/* Non-zero on worker ranks once rank 0's SMASH_GRAPH "go" message has been
 * observed by the MPI_Recv hook (so MPI_Finalize must not receive it again). */
static int master_done = 0;
     35 
     36 int
     37 smash_failure(void)
     38 {
     39 	int buf;
     40 	MPI_Status status;
     41 	size_t recv = 0;
     42 	int (*f)();
     43 
     44 	smash_dead = 1;
     45 	f = smash_get_lib_func(LIBMPI, "MPI_Recv");
     46 	while (recv != smash_world_size - smash_failures->size) {
     47 		f(&buf, 1, MPI_INT, MPI_ANY_SOURCE, 0xdead, MPI_COMM_WORLD, &status);
     48 		recv++;
     49 	}
     50 	MPI_Finalize();
     51 	exit(0);
     52 }
     53 
     54 int
     55 MPI_Recv(void *buf, int count, MPI_Datatype datatype, int source, int tag,
     56              MPI_Comm comm, MPI_Status *status) {
     57 	int (*f)(), res;
     58 
     59 	f = smash_get_lib_func(LIBMPI, "MPI_Recv");
     60 
     61 	while (1) {
     62 		res = f(buf, count, datatype, source, tag, comm, status);
     63 		if (status->MPI_TAG != 0xdead || status->MPI_TAG != SMASH_GRAPH)
     64 			break;
     65 		bzero(status, sizeof(MPI_Status));
     66 		master_done = status->MPI_TAG == SMASH_GRAPH;
     67 	}
     68 
     69 	smash_graph_msgs.msgs[smash_graph_msgs.i].src = status->MPI_SOURCE;
     70 	smash_graph_msgs.msgs[smash_graph_msgs.i].dst = smash_my_rank;
     71 	smash_graph_msgs.i++;
     72 
     73 	return res;
     74 }
     75 
     76 void *
     77 smash_get_lib_func(const char *lname, const char *fname)
     78 {
     79 	void *lib, *p;
     80 
     81 	if (!(lib = dlopen(lname, RTLD_LAZY)))
     82 		errx(EXIT_FAILURE, "%s", dlerror());
     83 
     84 	if (!(p = dlsym(lib, fname)))
     85 		errx(EXIT_FAILURE, "%s", dlerror());
     86 
     87 	dlclose(lib);
     88 	return p;
     89 }
     90 
/* SIGALRM handler: advance the callout machinery one tick.  smash_clock()
 * is declared in callout.h; presumably it runs expired timeouts — confirm
 * it is async-signal-safe, since it executes in signal context. */
static void
smash_handler(__attribute__((unused)) int signum)
{
	smash_clock();
}
     96 
     97 timer_t
     98 smash_setup_alarm(void)
     99 {
    100 	timer_t timerid;
    101 	struct sigaction sa;
    102 	struct sigevent sev;
    103 
    104 	sa.sa_handler = smash_handler;
    105 	sigemptyset(&sa.sa_mask);
    106 	sa.sa_flags = SA_RESTART;
    107 	sigaction(SIGALRM, &sa, NULL);
    108 
    109 	sev.sigev_notify = SIGEV_SIGNAL;
    110 	sev.sigev_signo = SIGALRM;
    111 	sev.sigev_value.sival_ptr = &timerid;
    112 	if (timer_create(CLOCK_REALTIME, &sev, &timerid) < 0)
    113 		errx(1, "timer_create");
    114 
    115 	return timerid;
    116 }
    117 
    118 int
    119 __libc_start_main(
    120 	int (*main)(int, char **, char **),
    121 	int argc,
    122 	char **argv,
    123 	int (*init)(int, char **, char **),
    124 	void (*fini)(void),
    125 	void (*rtld_fini)(void),
    126 	void *stack_end)
    127 {
    128 	int (*f)();
    129 
    130 	if (smash_parse_cfg(CFG_DELAY, (void **)&smash_delays) < 0)
    131 		errx(EXIT_FAILURE, "error in CFG_DELAY\n");
    132 
    133 	if (smash_parse_cfg(CFG_FAILURE, (void **)&smash_failures) < 0)
    134 		errx(EXIT_FAILURE, "error in CFG_FAILURE\n");
    135 
    136 	f = smash_get_lib_func(LIBSTD, "__libc_start_main");
    137 	smash_alarm = 0;
    138 	smash_dead = 0;
    139 	return f(main, argc, argv, init, fini, rtld_fini, stack_end);
    140 }
    141 
    142 int
    143 MPI_Init(int *argc, char ***argv)
    144 {
    145 	unsigned int i;
    146 	int (*f)(int *, char ***), res, rank;
    147 
    148         if (!smash_alarm) {
    149 		smash_timer_id = smash_setup_alarm();
    150 		smash_alarm = 1;
    151 	}
    152 
    153 	smash_graph_msgs.i = 0;
    154 
    155 	f = smash_get_lib_func(LIBMPI, "MPI_Init");
    156 	res = f(argc, argv);
    157 	MPI_Comm_rank(MPI_COMM_WORLD, &rank);
    158 	smash_my_rank = rank;
    159 	MPI_Comm_size(MPI_COMM_WORLD, &smash_world_size);
    160 
    161 	if (smash_failures != NULL) {
    162 		for (i = 0; i < smash_failures->size; ++i) {
    163 			if (smash_failures->failures[i].node == smash_my_rank) {
    164 				smash_timeout(smash_failure, 0, smash_failures->failures[i].time, NULL);
    165 			}
    166 		}
    167 	}
    168 	return res;
    169 }
    170 
    171 int
    172 save_graph(struct smash_graph_msgs *m)
    173 {
    174 	FILE *fs;
    175 	size_t i;
    176 	char *filepath;
    177 
    178 	filepath = getenv("SMASH_MPI_GRAPH");
    179 	if (!filepath)
    180 		filepath = "graph.dot";
    181 
    182 	if (!(fs = fopen(filepath, "w+")))
    183 		return -1;
    184 
    185 	fprintf(fs, "digraph SMASH_MPI {\n layout=twopi\n ranksep=3;\n ratio=auto;\n");
    186 	for (i = 0; i < m->i; ++i) {
    187 		fprintf(fs, "\"p%d\" -> \"p%d\" [ color=\"purple\" ];\n",
    188 		       m->msgs[i].src,
    189 		       m->msgs[i].dst);
    190 	}
    191 	fprintf(fs, "}");
    192 	fflush(fs);
    193 	return 0;
    194 }
    195 
    196 int
    197 MPI_Finalize(void)
    198 {
    199 	int (*f)(void);
    200 	size_t i, j;
    201 	int (*ssend)();
    202 	int (*recv)();
    203 
    204 	recv = smash_get_lib_func(LIBMPI, "MPI_Recv");
    205 	ssend = smash_get_lib_func(LIBMPI, "MPI_Ssend");
    206 
    207 	if (smash_failures != NULL) {
    208 		if (!smash_dead) {
    209 			for (i = 0; i < smash_failures->size; i++)
    210 				ssend(&smash_world_size, 1, MPI_INT, smash_failures->failures[i].node, 0xdead, MPI_COMM_WORLD);
    211 		}
    212 	}
    213 
    214 	int done;
    215 	if (smash_my_rank == 0) {
    216 		struct smash_graph_msgs tmp = {0};
    217 		MPI_Status status;
    218 		for (i = 1; i < (unsigned int)smash_world_size; ++i) {
    219 			done = 1;
    220 			ssend(&done, 1, MPI_INT, i, SMASH_GRAPH, MPI_COMM_WORLD, &status);
    221 			recv(&tmp, sizeof(struct smash_graph_msgs), MPI_CHAR,
    222 			     i, SMASH_GRAPH, MPI_COMM_WORLD,
    223 			     &status);
    224 
    225 			for (j = 0; j < tmp.i; ++j) {
    226 				smash_graph_msgs.msgs[smash_graph_msgs.i].src = tmp.msgs[j].src;
    227 				smash_graph_msgs.msgs[smash_graph_msgs.i].dst = tmp.msgs[j].dst;
    228 				smash_graph_msgs.i++;
    229 			}
    230 		}
    231 		/* Output graph */
    232 		save_graph(&smash_graph_msgs);
    233 	} else {
    234 		if (!master_done)
    235 			recv(&done, 1, MPI_INT, 0, SMASH_GRAPH, MPI_COMM_WORLD);
    236 		ssend(&smash_graph_msgs, sizeof(struct smash_graph_msgs),
    237 		     MPI_CHAR, 0, SMASH_GRAPH, MPI_COMM_WORLD);
    238 	}
    239 
    240 	free(smash_delays);
    241 	free(smash_failures);
    242 	f = smash_get_lib_func(LIBMPI, "MPI_Finalize");
    243 	return f();
    244 }
    245 
    246 int
    247 MPI_Ssend(const void *buf, int count, MPI_Datatype datatype, int dest,
    248              int tag, MPI_Comm comm)
    249 {
    250 	int (*f)();
    251 	unsigned int i;
    252 	struct mpi_send_args args = {
    253 		.count = count,
    254 		.datatype = datatype,
    255 		.dest = dest,
    256 		.tag = tag,
    257 		.comm = comm,
    258 	};
    259 	args.buf = malloc(sizeof(buf) * count);
    260 	memcpy(args.buf, buf, sizeof(buf) * count);
    261 
    262 	f = smash_get_lib_func(LIBMPI, "MPI_Ssend");
    263 
    264 	for (i = 0; i < smash_delays->size; ++i) {
    265 		/* If a delay in the config file matches our rank and the target rank, inject it in the callout struct. */
    266                 if (smash_delays->delays[i].dst == (unsigned int)dest &&
    267                     smash_delays->delays[i].src == smash_my_rank &&
    268 		    (smash_delays->delays[i].msg > 0 ||
    269 		     smash_delays->delays[i].msg == -1)) {
    270                         sem_wait(smash_timeout(f, 6, smash_delays->delays[i].delay, &args));
    271 			smash_delays->delays[i].msg -= 1 * (smash_delays->delays[i].msg != -1);
    272 			return 0;
    273                 }
    274         }
    275 	/* If there is no delay to apply, call MPI_Ssend directly. */
    276 	return f(buf, count, datatype, dest, tag, comm);
    277 }
    278 
    279 int
    280 MPI_Send(const void *buf, int count, MPI_Datatype datatype, int dest,
    281              int tag, MPI_Comm comm)
    282 {
    283 	int (*f)();
    284 	unsigned int i;
    285 	struct mpi_send_args args = {
    286 		.count = count,
    287 		.datatype = datatype,
    288 		.dest = dest,
    289 		.tag = tag,
    290 		.comm = comm,
    291 	};
    292 	args.buf = malloc(sizeof(buf) * count);
    293 	memcpy(args.buf, buf, sizeof(buf) * count);
    294 
    295 	f = smash_get_lib_func(LIBMPI, "MPI_Send");
    296 
    297 	for (i = 0; i < smash_delays->size; ++i) {
    298 		/* If a delay in the config file matches our rank and the target rank, inject it in the callout struct. */
    299                 if (smash_delays->delays[i].dst == (unsigned int)dest &&
    300                     smash_delays->delays[i].src == smash_my_rank &&
    301 		    (smash_delays->delays[i].msg > 0 ||
    302 		     smash_delays->delays[i].msg == -1)) {
    303                         smash_timeout(f, 6, smash_delays->delays[i].delay, &args);
    304 			smash_delays->delays[i].msg -= 1 * (smash_delays->delays[i].msg != -1);
    305 			return 0;
    306                 }
    307         }
    308 	/* If there is no delay to apply, call MPI_Send directly. */
    309 	return f(buf, count, datatype, dest, tag, comm);
    310 }