hooking.c (7549B)
1 #include <dlfcn.h> 2 #include <err.h> 3 #include <gnu/lib-names.h> 4 #include <mpi.h> 5 #include <stdlib.h> 6 #include <signal.h> 7 #include <string.h> 8 #include <unistd.h> 9 #include <stdio.h> 10 #include <time.h> 11 12 #include "callout.h" 13 #include "hooking.h" 14 #include "parser.h" 15 16 #define SMASH_GRAPH 0x1234 17 18 timer_t smash_timer_id; 19 unsigned int smash_my_rank; 20 int smash_dead, smash_world_size, smash_alarm; 21 22 struct cfg_delays *smash_delays; 23 struct cfg_failures *smash_failures; 24 25 struct smash_graph_msg { 26 int src, dst; 27 }; 28 29 struct smash_graph_msgs { 30 size_t i; 31 struct smash_graph_msg msgs[4096]; 32 } smash_graph_msgs; 33 34 static int master_done = 0; 35 36 int 37 smash_failure(void) 38 { 39 int buf; 40 MPI_Status status; 41 size_t recv = 0; 42 int (*f)(); 43 44 smash_dead = 1; 45 f = smash_get_lib_func(LIBMPI, "MPI_Recv"); 46 while (recv != smash_world_size - smash_failures->size) { 47 f(&buf, 1, MPI_INT, MPI_ANY_SOURCE, 0xdead, MPI_COMM_WORLD, &status); 48 recv++; 49 } 50 MPI_Finalize(); 51 exit(0); 52 } 53 54 int 55 MPI_Recv(void *buf, int count, MPI_Datatype datatype, int source, int tag, 56 MPI_Comm comm, MPI_Status *status) { 57 int (*f)(), res; 58 59 f = smash_get_lib_func(LIBMPI, "MPI_Recv"); 60 61 while (1) { 62 res = f(buf, count, datatype, source, tag, comm, status); 63 if (status->MPI_TAG != 0xdead || status->MPI_TAG != SMASH_GRAPH) 64 break; 65 bzero(status, sizeof(MPI_Status)); 66 master_done = status->MPI_TAG == SMASH_GRAPH; 67 } 68 69 smash_graph_msgs.msgs[smash_graph_msgs.i].src = status->MPI_SOURCE; 70 smash_graph_msgs.msgs[smash_graph_msgs.i].dst = smash_my_rank; 71 smash_graph_msgs.i++; 72 73 return res; 74 } 75 76 void * 77 smash_get_lib_func(const char *lname, const char *fname) 78 { 79 void *lib, *p; 80 81 if (!(lib = dlopen(lname, RTLD_LAZY))) 82 errx(EXIT_FAILURE, "%s", dlerror()); 83 84 if (!(p = dlsym(lib, fname))) 85 errx(EXIT_FAILURE, "%s", dlerror()); 86 87 dlclose(lib); 88 return p; 89 } 90 91 static void 92 smash_handler(__attribute__((unused)) int signum) 93 { 94 smash_clock(); 95 } 96 97 timer_t 98 smash_setup_alarm(void) 99 { 100 timer_t timerid; 101 struct sigaction sa; 102 struct sigevent sev; 103 104 sa.sa_handler = smash_handler; 105 sigemptyset(&sa.sa_mask); 106 sa.sa_flags = SA_RESTART; 107 sigaction(SIGALRM, &sa, NULL); 108 109 sev.sigev_notify = SIGEV_SIGNAL; 110 sev.sigev_signo = SIGALRM; 111 sev.sigev_value.sival_ptr = &timerid; 112 if (timer_create(CLOCK_REALTIME, &sev, &timerid) < 0) 113 errx(1, "timer_create"); 114 115 return timerid; 116 } 117 118 int 119 __libc_start_main( 120 int (*main)(int, char **, char **), 121 int argc, 122 char **argv, 123 int (*init)(int, char **, char **), 124 void (*fini)(void), 125 void (*rtld_fini)(void), 126 void *stack_end) 127 { 128 int (*f)(); 129 130 if (smash_parse_cfg(CFG_DELAY, (void **)&smash_delays) < 0) 131 errx(EXIT_FAILURE, "error in CFG_DELAY\n"); 132 133 if (smash_parse_cfg(CFG_FAILURE, (void **)&smash_failures) < 0) 134 errx(EXIT_FAILURE, "error in CFG_FAILURE\n"); 135 136 f = smash_get_lib_func(LIBSTD, "__libc_start_main"); 137 smash_alarm = 0; 138 smash_dead = 0; 139 return f(main, argc, argv, init, fini, rtld_fini, stack_end); 140 } 141 142 int 143 MPI_Init(int *argc, char ***argv) 144 { 145 unsigned int i; 146 int (*f)(int *, char ***), res, rank; 147 148 if (!smash_alarm) { 149 smash_timer_id = smash_setup_alarm(); 150 smash_alarm = 1; 151 } 152 153 smash_graph_msgs.i = 0; 154 155 f = smash_get_lib_func(LIBMPI, "MPI_Init"); 156 res = f(argc, argv); 157 MPI_Comm_rank(MPI_COMM_WORLD, &rank); 158 smash_my_rank = rank; 159 MPI_Comm_size(MPI_COMM_WORLD, &smash_world_size); 160 161 if (smash_failures != NULL) { 162 for (i = 0; i < smash_failures->size; ++i) { 163 if (smash_failures->failures[i].node == smash_my_rank) { 164 smash_timeout(smash_failure, 0, smash_failures->failures[i].time, NULL); 165 } 166 } 167 } 168 return res; 169 } 170 171 int 172 save_graph(struct smash_graph_msgs *m) 173 { 174 FILE *fs; 175 size_t i; 176 char *filepath; 177 178 filepath = getenv("SMASH_MPI_GRAPH"); 179 if (!filepath) 180 filepath = "graph.dot"; 181 182 if (!(fs = fopen(filepath, "w+"))) 183 return -1; 184 185 fprintf(fs, "digraph SMASH_MPI {\n layout=twopi\n ranksep=3;\n ratio=auto;\n"); 186 for (i = 0; i < m->i; ++i) { 187 fprintf(fs, "\"p%d\" -> \"p%d\" [ color=\"purple\" ];\n", 188 m->msgs[i].src, 189 m->msgs[i].dst); 190 } 191 fprintf(fs, "}"); 192 fflush(fs); 193 return 0; 194 } 195 196 int 197 MPI_Finalize(void) 198 { 199 int (*f)(void); 200 size_t i, j; 201 int (*ssend)(); 202 int (*recv)(); 203 204 recv = smash_get_lib_func(LIBMPI, "MPI_Recv"); 205 ssend = smash_get_lib_func(LIBMPI, "MPI_Ssend"); 206 207 if (smash_failures != NULL) { 208 if (!smash_dead) { 209 for (i = 0; i < smash_failures->size; i++) 210 ssend(&smash_world_size, 1, MPI_INT, smash_failures->failures[i].node, 0xdead, MPI_COMM_WORLD); 211 } 212 } 213 214 int done; 215 if (smash_my_rank == 0) { 216 struct smash_graph_msgs tmp = {0}; 217 MPI_Status status; 218 for (i = 1; i < (unsigned int)smash_world_size; ++i) { 219 done = 1; 220 ssend(&done, 1, MPI_INT, i, SMASH_GRAPH, MPI_COMM_WORLD, &status); 221 recv(&tmp, sizeof(struct smash_graph_msgs), MPI_CHAR, 222 i, SMASH_GRAPH, MPI_COMM_WORLD, 223 &status); 224 225 for (j = 0; j < tmp.i; ++j) { 226 smash_graph_msgs.msgs[smash_graph_msgs.i].src = tmp.msgs[j].src; 227 smash_graph_msgs.msgs[smash_graph_msgs.i].dst = tmp.msgs[j].dst; 228 smash_graph_msgs.i++; 229 } 230 } 231 /* Output graph */ 232 save_graph(&smash_graph_msgs); 233 } else { 234 if (!master_done) 235 recv(&done, 1, MPI_INT, 0, SMASH_GRAPH, MPI_COMM_WORLD); 236 ssend(&smash_graph_msgs, sizeof(struct smash_graph_msgs), 237 MPI_CHAR, 0, SMASH_GRAPH, MPI_COMM_WORLD); 238 } 239 240 free(smash_delays); 241 free(smash_failures); 242 f = smash_get_lib_func(LIBMPI, "MPI_Finalize"); 243 return f(); 244 } 245 246 int 247 MPI_Ssend(const void *buf, int count, MPI_Datatype datatype, int dest, 248 int tag, MPI_Comm comm) 249 { 250 int (*f)(); 251 unsigned int i; 252 struct mpi_send_args args = { 253 .count = count, 254 .datatype = datatype, 255 .dest = dest, 256 .tag = tag, 257 .comm = comm, 258 }; 259 args.buf = malloc(sizeof(buf) * count); 260 memcpy(args.buf, buf, sizeof(buf) * count); 261 262 f = smash_get_lib_func(LIBMPI, "MPI_Ssend"); 263 264 for (i = 0; i < smash_delays->size; ++i) { 265 /* If a delay in the config file matches our rank and the target rank, inject it in the callout struct. */ 266 if (smash_delays->delays[i].dst == (unsigned int)dest && 267 smash_delays->delays[i].src == smash_my_rank && 268 (smash_delays->delays[i].msg > 0 || 269 smash_delays->delays[i].msg == -1)) { 270 sem_wait(smash_timeout(f, 6, smash_delays->delays[i].delay, &args)); 271 smash_delays->delays[i].msg -= 1 * (smash_delays->delays[i].msg != -1); 272 return 0; 273 } 274 } 275 /* If there is no delay to apply, call MPI_Ssend directly. */ 276 return f(buf, count, datatype, dest, tag, comm); 277 } 278 279 int 280 MPI_Send(const void *buf, int count, MPI_Datatype datatype, int dest, 281 int tag, MPI_Comm comm) 282 { 283 int (*f)(); 284 unsigned int i; 285 struct mpi_send_args args = { 286 .count = count, 287 .datatype = datatype, 288 .dest = dest, 289 .tag = tag, 290 .comm = comm, 291 }; 292 args.buf = malloc(sizeof(buf) * count); 293 memcpy(args.buf, buf, sizeof(buf) * count); 294 295 f = smash_get_lib_func(LIBMPI, "MPI_Send"); 296 297 for (i = 0; i < smash_delays->size; ++i) { 298 /* If a delay in the config file matches our rank and the target rank, inject it in the callout struct. */ 299 if (smash_delays->delays[i].dst == (unsigned int)dest && 300 smash_delays->delays[i].src == smash_my_rank && 301 (smash_delays->delays[i].msg > 0 || 302 smash_delays->delays[i].msg == -1)) { 303 smash_timeout(f, 6, smash_delays->delays[i].delay, &args); 304 smash_delays->delays[i].msg -= 1 * (smash_delays->delays[i].msg != -1); 305 return 0; 306 } 307 } 308 /* If there is no delay to apply, call MPI_Send directly. */ 309 return f(buf, count, datatype, dest, tag, comm); 310 }