diff --git a/map-reduce/word-count.c b/map-reduce/word-count.c index 5d4f2b0..8cf0b25 100644 --- a/map-reduce/word-count.c +++ b/map-reduce/word-count.c @@ -4,70 +4,92 @@ #include -#define container_of(list_ptr, container_type, member_name) \ - ({ \ - const typeof(((container_type *) 0)->member_name) *__member_ptr = \ - (list_ptr); \ - (container_type *) ((char *) __member_ptr - \ - offsetof(container_type, member_name)); \ - }) +#include +#include +#include -struct list_entry { - struct list_entry *next, *prev; +struct hlist_head { + struct hlist_node *first; }; -#define list_element(list_ptr, type, member) \ - container_of(list_ptr, type, member) +struct hlist_node { + struct hlist_node *next, **pprev; +}; -#define list_first(root_ptr, type, member) \ - list_element((root_ptr)->next, type, member) +#define HLIST_HEAD_INIT \ + { \ + .first = NULL \ + } +#define HLIST_HEAD(name) struct hlist_head name = {.first = NULL} +#define INIT_HLIST_HEAD(ptr) ((ptr)->first = NULL) -static inline struct list_entry *list_next(struct list_entry *root, - struct list_entry *current) +static inline void INIT_HLIST_NODE(struct hlist_node *h) { - if ((root == root->next) || (current->next == root)) return NULL; - return current->next; + h->next = NULL; + h->pprev = NULL; } -/* FIXME: this forbids having 2 list_for_each in the same function, because the - * variable __ptr will be defined twice, which results in a compilation error. - * The __ptr is necessary because some functions delete iter while traversing - * the list. - */ -#define list_for_each_forward(root_ptr, iter) \ - struct list_entry *__ptr; \ - for (iter = (root_ptr)->next, __ptr = (struct list_entry *) (iter)->next; \ - iter != (root_ptr); iter = (typeof((iter))) __ptr, \ - __ptr = (struct list_entry *) iter->next) - -#define list_for_each(root_ptr, iter) list_for_each_forward(root_ptr, iter) +static inline int hlist_empty(const struct hlist_head *h) +{ + return !h->first; +} -static inline void list_root_init(struct list_entry *root) +static inline void hlist_add_head(struct hlist_node *n, struct hlist_head *h) { - root->next = root->prev = root; + struct hlist_node *first = h->first; + n->next = first; + if (first) + first->pprev = &n->next; + h->first = n; + n->pprev = &h->first; } -static inline void list_add(struct list_entry *root, struct list_entry *entry) +#include + +static inline bool hlist_is_singular_node(struct hlist_node *n, + struct hlist_head *h) { - struct list_entry *prev_entry = root; - struct list_entry *next_entry = root->next; - entry->next = next_entry, entry->prev = prev_entry; - prev_entry->next = entry, next_entry->prev = entry; + return !n->next && n->pprev == &h->first; } -#define list_add_prev(root, entry) list_add((root)->prev, (entry)) +#define container_of(list_ptr, container_type, member_name) \ + ({ \ + const typeof(((container_type *) 0)->member_name) *__member_ptr = \ + (list_ptr); \ + (container_type *) ((char *) __member_ptr - \ + offsetof(container_type, member_name)); \ + }) + +#define hlist_entry(ptr, type, member) container_of(ptr, type, member) -#define list_empty(root) (root == (root)->next) +#define hlist_first_entry(head, type, member) \ + hlist_entry((head)->first, type, member) -#include -#include +#define hlist_for_each(pos, head) \ + for (pos = (head)->first; pos; pos = pos->next) + +#define hlist_entry_safe(ptr, type, member) \ + ({ \ + typeof(ptr) ____ptr = (ptr); \ + ____ptr ? hlist_entry(____ptr, type, member) : NULL; \ + }) + +#define hlist_for_each_entry(pos, head, member) \ + for (pos = hlist_entry_safe((head)->first, typeof(*(pos)), member); pos; \ + pos = hlist_entry_safe((pos)->member.next, typeof(*(pos)), member)) + +static inline void __hash_init(struct hlist_head *ht, unsigned int sz) +{ + for (unsigned int i = 0; i < sz; i++) + INIT_HLIST_HEAD(&ht[i]); +} typedef uint32_t hash_t; /* A node of the table */ struct ht_node { hash_t hash; - struct list_entry list; + struct hlist_node list; }; /* user-defined functions */ @@ -79,7 +101,7 @@ struct htable { hashfunc_t *hashfunc; cmp_t *cmp; uint32_t n_buckets; - struct list_entry *buckets; + struct hlist_head *buckets; }; /* Initializes a hash table */ @@ -90,9 +112,8 @@ static inline int ht_init(struct htable *h, { h->hashfunc = hashfunc, h->cmp = cmp; h->n_buckets = n_buckets; - h->buckets = malloc(sizeof(struct list_entry) * n_buckets); - for (size_t i = 0; i < h->n_buckets; i++) list_root_init(&h->buckets[i]); - + h->buckets = malloc(sizeof(struct hlist_head) * n_buckets); + __hash_init(h->buckets, h->n_buckets); return 0; } @@ -112,13 +133,15 @@ static inline struct ht_node *ht_find(struct htable *h, void *key) uint32_t bkt; h->hashfunc(key, &hval, &bkt); - struct list_entry *head = &h->buckets[bkt], *iter; - list_for_each (head, iter) { - struct ht_node *n = list_element(iter, struct ht_node, list); + struct hlist_head *head = &h->buckets[bkt]; + struct ht_node *n; + hlist_for_each_entry (n, head, list) { if (n->hash == hval) { int res = h->cmp(n, key); - if (!res) return n; - if (res > 0) return NULL; + if (!res) + return n; + if (res > 0) + return NULL; } } return NULL; @@ -127,6 +150,8 @@ static inline struct ht_node *ht_find(struct htable *h, void *key) /* Insert a new element with the key 'key' in the htable. * Return 0 if success. */ +#include + static inline int ht_insert(struct htable *h, struct ht_node *n, void *key) { hash_t hval; @@ -134,38 +159,51 @@ static inline int ht_insert(struct htable *h, struct ht_node *n, void *key) h->hashfunc(key, &hval, &bkt); n->hash = hval; - struct list_entry *head = &h->buckets[bkt], *iter; - list_for_each (head, iter) { - struct ht_node *tmp = list_element(iter, struct ht_node, list); + struct hlist_head *head = &h->buckets[bkt]; + struct hlist_node *iter; + hlist_for_each(iter, head) + { + struct ht_node *tmp = hlist_entry(iter, struct ht_node, list); if (tmp->hash >= hval) { int cmp = h->cmp(tmp, key); if (!cmp) /* already exist */ return -1; if (cmp > 0) { - list_add_prev(iter, &n->list); + hlist_add_head(&n->list, head); return 0; } } } - list_add_prev(head, &n->list); + hlist_add_head(&n->list, head); return 0; } static inline struct ht_node *ht_get_first(struct htable *h, uint32_t bucket) { - struct list_entry *head = &h->buckets[bucket]; - if (list_empty(head)) return NULL; - return list_first(head, struct ht_node, list); + struct hlist_head *head = &h->buckets[bucket]; + if (hlist_empty(head)) + return NULL; + return hlist_first_entry(head, struct ht_node, list); +} + +static inline struct hlist_node *hlist_next(struct hlist_head *root, + struct hlist_node *current) +{ + if ((hlist_empty(root)) || hlist_is_singular_node(current, root) || + !current) + return NULL; + return current->next; } static inline struct ht_node *ht_get_next(struct htable *h, uint32_t bucket, struct ht_node *n) { - struct list_entry *ln = list_next(&h->buckets[bucket], &n->list); - if (!ln) return NULL; - return list_element(ln, struct ht_node, list); + struct hlist_node *ln = hlist_next(&h->buckets[bucket], &n->list); + if (!ln) + return NULL; + return hlist_entry(ln, struct ht_node, list); } /* cache of words. Count the number of word using a modified hash table */ @@ -301,9 +339,11 @@ int wc_add_word(uint32_t tid, char *word, uint32_t count) struct ht_node *n; if (!(n = ht_find(&cache->htable, word))) { /* word was absent. Allocate a new wc_word */ - if (!(w = calloc(1, sizeof(struct wc_word)))) return -1; + if (!(w = calloc(1, sizeof(struct wc_word)))) + return -1; - if (count > (MAX_WORD_SIZE - 1)) w->full_word = calloc(1, count + 1); + if (count > (MAX_WORD_SIZE - 1)) + w->full_word = calloc(1, count + 1); wc_strncpy(GET_WORD(w), word, count); @@ -346,7 +386,8 @@ int wc_merge_results(uint32_t tid, uint32_t n_threads) uint32_t n_workers; /* Keep the number of workers <= nbthread */ if (n_threads > n_buckets) { - if (tid > n_buckets - 1) return 0; + if (tid > n_buckets - 1) + return 0; n_workers = n_buckets; } else n_workers = n_threads; @@ -358,7 +399,8 @@ int wc_merge_results(uint32_t tid, uint32_t n_threads) uint32_t wk_bstart = wk_buckets * tid, wk_bend = wk_bstart + wk_buckets; /* last thread must also do last buckets */ - if ((tid == (n_workers - 1))) wk_bend += n_buckets % n_workers; + if ((tid == (n_workers - 1))) + wk_bend += n_buckets % n_workers; for (size_t i = 0; i < n_threads; i++) { struct wc_cache *cache = &thread_caches[i]; @@ -389,7 +431,8 @@ int wc_print(int id) bkt_total++, total++; count_total += w->counter; } - if (!bkt_total) empty_bkt++; + if (!bkt_total) + empty_bkt++; bkt_total = 0; } printf("Words: %d, word counts: %d, full buckets: %d (ideal %d)\n", total, @@ -402,10 +445,12 @@ static int __wc_destroy(struct wc_cache *wcc, int id) int valid = (id == -1); for (uint32_t j = 0; j < n_buckets; j++) { struct ht_node *iter = ht_get_first(&wcc->htable, j); - for (; iter; iter = ht_get_next(&wcc->htable, j, iter)) { + struct ht_node *tmp = ht_get_next(&wcc->htable, j, iter); + for (; tmp; iter = tmp, tmp = ht_get_next(&wcc->htable, j, tmp)) { struct wc_word *w = valid ? container_of(iter, struct wc_word, node_main) : container_of(iter, struct wc_word, node); + free(w->full_word); free(w); } @@ -418,12 +463,15 @@ static int __wc_destroy(struct wc_cache *wcc, int id) int wc_destroy(uint32_t n_threads) { for (size_t i = 0; i < n_threads; i++) { - if (__wc_destroy(&thread_caches[i], i)) return -1; - if (ht_destroy(&thread_caches[i].htable)) return -1; + if (__wc_destroy(&thread_caches[i], i)) + return -1; + if (ht_destroy(&thread_caches[i].htable)) + return -1; } free(thread_caches); - if (ht_destroy(&main_cache.htable)) return -1; + if (ht_destroy(&main_cache.htable)) + return -1; return 0; } @@ -467,7 +515,8 @@ static inline int fa_init(char *file, uint32_t n_threads, off_t *fsz) } file_content = mmap(NULL, file_size, PROT_READ, MMAP_FLAGS, fd, 0); - if (file_content == MAP_FAILED) file_content = NULL; + if (file_content == MAP_FAILED) + file_content = NULL; *fsz = file_size; return 0; @@ -478,7 +527,8 @@ static inline int fa_init(char *file, uint32_t n_threads, off_t *fsz) */ static inline int fa_read_init() { - if (!file_content && !(worker_buffer = malloc(BUFFER_SIZE))) return -1; + if (!file_content && !(worker_buffer = malloc(BUFFER_SIZE))) + return -1; return 0; } @@ -537,9 +587,11 @@ int mr_init(void) return -1; } - if (fa_init(file_name, n_threads, &file_size)) return -1; + if (fa_init(file_name, n_threads, &file_size)) + return -1; - if (wc_init(n_threads, file_size / MAX_WORD_SIZE)) return -1; + if (wc_init(n_threads, file_size / MAX_WORD_SIZE)) + return -1; return 0; } @@ -551,9 +603,11 @@ int mr_destroy(void) return -1; } - if (fa_init(file_name, n_threads, &file_size)) return -1; + if (fa_init(file_name, n_threads, &file_size)) + return -1; - if (wc_destroy(n_threads)) return -1; + if (wc_destroy(n_threads)) + return -1; return 0; } @@ -598,7 +652,8 @@ static inline int add_sep(uint32_t tid) { if (count) { word[count] = '\0'; /* Add current word */ - if (wc_add_word(tid, word, count)) return -1; + if (wc_add_word(tid, word, count)) + return -1; count = 0; } return 0; @@ -625,13 +680,15 @@ static int buff_proceed(uint32_t tid, char *buff, size_t size, char last) /* Configure the buffer slices of each worker */ static int buff_init(uint32_t tid) { - if (fa_read_init()) return -1; + if (fa_read_init()) + return -1; worker_slice = file_size / n_threads; worker_current = worker_slice * tid; /* Last thread handle remaining bytes */ - if (tid == (n_threads - 1)) worker_slice += file_size % n_threads; + if (tid == (n_threads - 1)) + worker_slice += file_size % n_threads; off_t worker_end = worker_current + worker_slice; @@ -640,18 +697,24 @@ static int buff_init(uint32_t tid) */ char *buff; do { - if (tid == 0) break; - if (fa_read(tid, &buff, 1, worker_current) != 1) return -1; - if (!IS_LETTER(*buff)) break; + if (tid == 0) + break; + if (fa_read(tid, &buff, 1, worker_current) != 1) + return -1; + if (!IS_LETTER(*buff)) + break; worker_current++; worker_slice--; } while (*buff); /* add letters of the last word if we are not the last thread */ do { - if (tid == (n_threads - 1)) break; - if (fa_read(tid, &buff, 1, worker_end) != 1) return -1; - if (!IS_LETTER(*buff)) break; + if (tid == (n_threads - 1)) + break; + if (fa_read(tid, &buff, 1, worker_end) != 1) + return -1; + if (!IS_LETTER(*buff)) + break; worker_end++, worker_slice++; } while (*buff); @@ -661,21 +724,25 @@ static int buff_init(uint32_t tid) static int buff_destroy() { free(word); - if (fa_read_destroy()) return -1; + if (fa_read_destroy()) + return -1; return 0; } /* Read a buffer from the file */ static int buff_read(uint32_t tid, char **buff, off_t *size, char *last) { - if (!worker_slice) return 0; + if (!worker_slice) + return 0; off_t size_read = fa_read(tid, buff, worker_slice, worker_current); - if (size_read == -1) return -1; + if (size_read == -1) + return -1; *size = size_read; worker_current += size_read, worker_slice -= size_read; - if (!worker_slice) *last = 1; + if (!worker_slice) + *last = 1; return 0; } @@ -683,19 +750,23 @@ void *mr_map(void *id) { uint32_t tid = ((struct thread_info *) id)->thread_num; int ret = buff_init(tid); - if (ret) goto bail; + if (ret) + goto bail; char *buff; off_t size = 0; char last = 0; while (!(ret = buff_read(tid, &buff, &size, &last))) { - if (!size) break; - if (buff_proceed(tid, buff, size, last)) goto bail; + if (!size) + break; + if (buff_proceed(tid, buff, size, last)) + goto bail; if (last) /* If this was the last buffer */ break; } - if (buff_destroy()) goto bail; + if (buff_destroy()) + goto bail; /* wait for other worker before merging */ if (pthread_barrier_wait(&barrier) > 0) { @@ -721,13 +792,16 @@ static struct thread_info *tinfo; static int parse_args(int argc, char **argv) { - if (argc < 3) return -1; + if (argc < 3) + return -1; file_name = argv[1]; - if (!file_name) return -1; + if (!file_name) + return -1; n_threads = atoi(argv[2]); - if (!n_threads) return -1; + if (!n_threads) + return -1; return 0; } @@ -746,7 +820,8 @@ static void run_threads(void) static void wait_threads() { for (size_t i = 0; i < n_threads; i++) - if (pthread_join(tinfo[i].thread_id, NULL)) throw_err("thread join"); + if (pthread_join(tinfo[i].thread_id, NULL)) + throw_err("thread join"); free(tinfo); } @@ -767,18 +842,22 @@ int main(int argc, char **argv) } double start = now(); - if (mr_init()) exit(EXIT_FAILURE); + if (mr_init()) + exit(EXIT_FAILURE); run_threads(); wait_threads(); - if (mr_reduce()) exit(EXIT_FAILURE); + if (mr_reduce()) + exit(EXIT_FAILURE); /* Done here, to avoid counting the printing */ double end = now(); - if (mr_print()) exit(EXIT_FAILURE); - if (mr_destroy()) exit(EXIT_FAILURE); + if (mr_print()) + exit(EXIT_FAILURE); + if (mr_destroy()) + exit(EXIT_FAILURE); printf("Done in %g msec\n", end - start);