Skip to content

Commit 1b93a17

Browse files
authored
Merge pull request #7149 from bosilca/fix/datatype_overflow
Prevent overflow when dealing with datatype count.
2 parents 9caf43a + 3de636d commit 1b93a17

14 files changed

+41
-32
lines changed

opal/datatype/opal_convertor.c

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -403,7 +403,7 @@ opal_convertor_create_stack_at_begining( opal_convertor_t* convertor,
403403
pStack[1].count = pElems[0].loop.loops;
404404
pStack[1].type = OPAL_DATATYPE_LOOP;
405405
} else {
406-
pStack[1].count = pElems[0].elem.count * pElems[0].elem.blocklen;
406+
pStack[1].count = (size_t)pElems[0].elem.count * pElems[0].elem.blocklen;
407407
pStack[1].type = pElems[0].elem.common.type;
408408
}
409409
return OPAL_SUCCESS;

opal/datatype/opal_convertor_raw.c

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -125,7 +125,7 @@ opal_convertor_raw( opal_convertor_t* pConvertor,
125125
if( pElem->elem.common.flags & OPAL_DATATYPE_FLAG_DATA ) {
126126
const ddt_elem_desc_t* current = &(pElem->elem);
127127

128-
if( count_desc != (current->count * current->blocklen) ) { /* Not the full element description */
128+
if( count_desc != ((size_t)current->count * current->blocklen) ) { /* Not the full element description */
129129
if( (do_now = count_desc % current->blocklen) ) {
130130
do_now = current->blocklen - do_now; /* how much left in the block */
131131
source_base += current->disp;
@@ -152,7 +152,7 @@ opal_convertor_raw( opal_convertor_t* pConvertor,
152152
source_base += current->disp;
153153

154154
do_now = current->count;
155-
if( count_desc != (current->count * current->blocklen) ) {
155+
if( count_desc != ((size_t)current->count * current->blocklen) ) {
156156
do_now = count_desc / current->blocklen;
157157
assert( 0 == (count_desc % current->blocklen) );
158158
}

opal/datatype/opal_datatype.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -376,7 +376,7 @@ opal_datatype_create_from_packed_description( void** packed_buffer,
376376
* argument, the number of bytes of the gap at the beginning.
377377
*/
378378
static inline ptrdiff_t
379-
opal_datatype_span( const opal_datatype_t* pData, int64_t count,
379+
opal_datatype_span( const opal_datatype_t* pData, size_t count,
380380
ptrdiff_t* gap)
381381
{
382382
if (OPAL_UNLIKELY(0 == pData->size) || (0 == count)) {
@@ -386,7 +386,7 @@ opal_datatype_span( const opal_datatype_t* pData, int64_t count,
386386
*gap = pData->true_lb;
387387
ptrdiff_t extent = (pData->ub - pData->lb);
388388
ptrdiff_t true_extent = (pData->true_ub - pData->true_lb);
389-
return true_extent + (count - 1) * extent;
389+
return true_extent + extent * (count - 1);
390390
}
391391

392392
#if OPAL_ENABLE_DEBUG

opal/datatype/opal_datatype_add.c

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -285,7 +285,7 @@ int32_t opal_datatype_add( opal_datatype_t* pdtBase, const opal_datatype_t* pdtA
285285
pLast->elem.common.flags = pdtAdd->flags & ~(OPAL_DATATYPE_FLAG_COMMITTED);
286286
pLast->elem.common.type = pdtAdd->id;
287287
pLast->elem.disp = disp;
288-
pLast->elem.extent = count * extent;
288+
pLast->elem.extent = (ptrdiff_t)count * extent;
289289
/* assume predefined datatypes without extent, aka. contiguous */
290290
pLast->elem.count = 1;
291291
pLast->elem.blocklen = count;
@@ -328,9 +328,18 @@ int32_t opal_datatype_add( opal_datatype_t* pdtBase, const opal_datatype_t* pdtA
328328
pLast->elem.count = count;
329329
pLast->elem.extent = extent;
330330
}
331-
} else if( extent == (ptrdiff_t)(pLast->elem.count * pLast->elem.extent) ) {
331+
} else if( extent == ((ptrdiff_t)pLast->elem.count * pLast->elem.extent) ) {
332332
/* It's just a repetition of the same element, increase the count */
333-
pLast->elem.count *= count;
333+
/* We need to protect against the case where the multiplication below results in a
334+
* number larger than the max uint32_t. In the unlikely situation where that's the case
335+
* we should not try to optimize the item further but instead fall back and build a loop
336+
* around it.
337+
*/
338+
uint32_t cnt = pLast->elem.count * count;
339+
if( cnt < pLast->elem.count ) {
340+
goto build_loop;
341+
}
342+
pLast->elem.count = cnt; /* we're good, merge the elements */
334343
} else {
335344
/* No luck here, no optimization can be applied. Fall back to the
336345
* normal case where we add a loop around the datatype.

opal/datatype/opal_datatype_copy.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -128,7 +128,7 @@ static inline int32_t _copy_content_same_ddt( const opal_datatype_t* datatype, i
128128
DO_DEBUG( opal_output( 0, "_copy_content_same_ddt( %p, %d, dst %p, src %p )\n",
129129
(void*)datatype, count, (void*)destination_base, (void*)source_base ); );
130130

131-
iov_len_local = count * datatype->size;
131+
iov_len_local = (size_t)count * datatype->size;
132132

133133
/* If we have to copy a contiguous datatype then simply
134134
* do a MEM_OP.

opal/datatype/opal_datatype_dump.c

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -98,7 +98,7 @@ int opal_datatype_dump_data_desc( dt_elem_desc_t* pDesc, int nbElems, char* ptr,
9898
pDesc->end_loop.items, pDesc->end_loop.first_elem_disp,
9999
pDesc->end_loop.size );
100100
else
101-
index += snprintf( ptr + index, length - index, "count %" PRIsize_t " disp 0x%tx (%td) blen %u extent %td (size %zd)\n",
101+
index += snprintf( ptr + index, length - index, "count %u disp 0x%tx (%td) blen %" PRIsize_t " extent %td (size %zd)\n",
102102
pDesc->elem.count, pDesc->elem.disp, pDesc->elem.disp, pDesc->elem.blocklen,
103103
pDesc->elem.extent, (pDesc->elem.count * pDesc->elem.blocklen * opal_datatype_basicDatatypes[pDesc->elem.common.type]->size) );
104104
pDesc++;

opal/datatype/opal_datatype_fake_stack.c

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ int opal_convertor_create_stack_with_pos_general( opal_convertor_t* pConvertor,
4444
int pos_desc; /* actual position in the description of the derived datatype */
4545
size_t lastLength = 0;
4646
const opal_datatype_t* pData = pConvertor->pDesc;
47-
size_t loop_length, *remoteLength, remote_size;
47+
size_t loop_length, *remoteLength, remote_size;
4848
size_t resting_place = starting_point;
4949
dt_elem_desc_t* pElems;
5050
size_t count;
@@ -152,7 +152,7 @@ int opal_convertor_create_stack_with_pos_general( opal_convertor_t* pConvertor,
152152
if( OPAL_DATATYPE_LOOP == pElems->elem.common.type ) {
153153
remoteLength[pConvertor->stack_pos] += loop_length;
154154
PUSH_STACK( pStack, pConvertor->stack_pos, pos_desc, OPAL_DATATYPE_LOOP,
155-
pElems->loop.loops, pStack->disp );
155+
pElems->loop.loops, pStack->disp );
156156
pos_desc++;
157157
pElems++;
158158
remoteLength[pConvertor->stack_pos] = 0;
@@ -161,7 +161,7 @@ int opal_convertor_create_stack_with_pos_general( opal_convertor_t* pConvertor,
161161
while( pElems->elem.common.flags & OPAL_DATATYPE_FLAG_DATA ) {
162162
/* now here we have a basic datatype */
163163
const opal_datatype_t* basic_type = BASIC_DDT_FROM_ELEM( (*pElems) );
164-
lastLength = pElems->elem.count * basic_type->size;
164+
lastLength = (size_t)pElems->elem.count * basic_type->size;
165165
if( resting_place < lastLength ) {
166166
int32_t cnt = (int32_t)(resting_place / basic_type->size);
167167
loop_length += (cnt * basic_type->size);

opal/datatype/opal_datatype_get_count.c

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -69,14 +69,14 @@ ssize_t opal_datatype_get_element_count( const opal_datatype_t* datatype, size_t
6969
while( pElems[pos_desc].elem.common.flags & OPAL_DATATYPE_FLAG_DATA ) {
7070
/* now here we have a basic datatype */
7171
const opal_datatype_t* basic_type = BASIC_DDT_FROM_ELEM(pElems[pos_desc]);
72-
local_size = (pElems[pos_desc].elem.count * pElems[pos_desc].elem.blocklen) * basic_type->size;
72+
local_size = ((size_t)pElems[pos_desc].elem.count * pElems[pos_desc].elem.blocklen) * basic_type->size;
7373
if( local_size >= iSize ) {
7474
local_size = iSize / basic_type->size;
7575
nbElems += (int32_t)local_size;
7676
iSize -= local_size * basic_type->size;
7777
return (iSize == 0 ? nbElems : -1);
7878
}
79-
nbElems += (pElems[pos_desc].elem.count * pElems[pos_desc].elem.blocklen);
79+
nbElems += ((size_t)pElems[pos_desc].elem.count * pElems[pos_desc].elem.blocklen);
8080
iSize -= local_size;
8181
pos_desc++; /* advance to the next data */
8282
}
@@ -131,7 +131,7 @@ int32_t opal_datatype_set_element_count( const opal_datatype_t* datatype, size_t
131131
while( pElems[pos_desc].elem.common.flags & OPAL_DATATYPE_FLAG_DATA ) {
132132
/* now here we have a basic datatype */
133133
const opal_datatype_t* basic_type = BASIC_DDT_FROM_ELEM(pElems[pos_desc]);
134-
local_length = (pElems[pos_desc].elem.count * pElems[pos_desc].elem.blocklen);
134+
local_length = ((size_t)pElems[pos_desc].elem.count * pElems[pos_desc].elem.blocklen);
135135
if( local_length >= count ) {
136136
*length += count * basic_type->size;
137137
return 0;
@@ -188,10 +188,10 @@ int opal_datatype_compute_ptypes( opal_datatype_t* datatype )
188188
}
189189
while( pElems[pos_desc].elem.common.flags & OPAL_DATATYPE_FLAG_DATA ) {
190190
/* now here we have a basic datatype */
191-
datatype->ptypes[pElems[pos_desc].elem.common.type] += pElems[pos_desc].elem.count * pElems[pos_desc].elem.blocklen;
192-
nbElems += pElems[pos_desc].elem.count * pElems[pos_desc].elem.blocklen;
191+
datatype->ptypes[pElems[pos_desc].elem.common.type] += (size_t)pElems[pos_desc].elem.count * pElems[pos_desc].elem.blocklen;
192+
nbElems += (size_t)pElems[pos_desc].elem.count * pElems[pos_desc].elem.blocklen;
193193

194-
DUMP( " compute_ptypes-add: type %d count %"PRIsize_t" (total type %"PRIsize_t" total %lld)\n",
194+
DUMP( " compute_ptypes-add: type %d count %"PRIsize_t" (total type %u total %lld)\n",
195195
pElems[pos_desc].elem.common.type, datatype->ptypes[pElems[pos_desc].elem.common.type],
196196
pElems[pos_desc].elem.count, nbElems );
197197
pos_desc++; /* advance to the next data */

opal/datatype/opal_datatype_internal.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -156,8 +156,8 @@ typedef struct ddt_elem_id_description ddt_elem_id_description;
156156
*/
157157
struct ddt_elem_desc {
158158
ddt_elem_id_description common; /**< basic data description and flags */
159-
uint32_t blocklen; /**< number of elements on each block */
160-
size_t count; /**< number of blocks */
159+
uint32_t count; /**< number of blocks */
160+
size_t blocklen; /**< number of elements on each block */
161161
ptrdiff_t extent; /**< extent of each block (in bytes) */
162162
ptrdiff_t disp; /**< displacement of the first block */
163163
};

opal/datatype/opal_datatype_pack.c

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -275,7 +275,7 @@ opal_generic_simple_pack_function( opal_convertor_t* pConvertor,
275275
iov_len_local = iov[iov_count].iov_len;
276276

277277
if( pElem->elem.common.flags & OPAL_DATATYPE_FLAG_DATA ) {
278-
if( (pElem->elem.count * pElem->elem.blocklen) != count_desc ) {
278+
if( ((size_t)pElem->elem.count * pElem->elem.blocklen) != count_desc ) {
279279
/* we have a partial (less than blocklen) basic datatype */
280280
int rc = PACK_PARTIAL_BLOCKLEN( pConvertor, pElem, count_desc,
281281
conv_ptr, iov_ptr, iov_len_local );

0 commit comments

Comments
 (0)