diff --git a/src/qonnx/analysis/inference_cost.py b/src/qonnx/analysis/inference_cost.py
index c821d26a..d5da873b 100644
--- a/src/qonnx/analysis/inference_cost.py
+++ b/src/qonnx/analysis/inference_cost.py
@@ -31,6 +31,96 @@
 from qonnx.util.basic import get_by_name


+def patches_for_each_pixel_2d(image_h, image_w, kernel_h, kernel_w, stride_h, stride_w):
+    """
+    Computes the number of patches covering each pixel in a 2D image.
+
+    Parameters:
+    image_h, image_w : int
+        Height and width of the image.
+    kernel_h, kernel_w : int
+        Height and width of the kernel.
+    stride_h, stride_w : int
+        Stride along the height and width.
+
+    Returns:
+    np.array
+        Matrix indicating the number of patches covering each pixel.
+    """
+    patches_for_each_pixel = np.zeros((image_h, image_w), dtype=int)
+
+    patches_idx_w = select_idx_for_patches_1d(image_w, kernel_w, stride_w)
+    patches_idx_h = select_idx_for_patches_1d(image_h, kernel_h, stride_h)
+
+    patches_for_each_pixel_h = num_patches_for_each_pixel(patches_idx_h, image_h)
+    patches_for_each_pixel_w = num_patches_for_each_pixel(patches_idx_w, image_w)
+
+    for j in range(image_h):
+        for k in range(image_w):
+            patches_for_each_pixel[j, k] = patches_for_each_pixel_h[j] * patches_for_each_pixel_w[k]
+
+    return patches_for_each_pixel
+
+
+def select_idx_for_patches_1d(img_size, kernel, stride):
+    """
+    Computes the start and end indices for patches in 1D.
+
+    Parameters:
+    img_size : int
+        Size of the image along one dimension.
+    kernel : int
+        Size of the kernel.
+    stride : int
+        Size of the stride.
+
+    Returns:
+    list of np.array
+        List containing an array of (start, end) indices, one row per patch.
+    """
+    patches_idx = []
+
+    num_patches = int(np.ceil((img_size - kernel) / stride + 1))
+    indices = np.zeros((num_patches, 2), dtype=int)
+
+    for i in range(num_patches):
+        ri_index = 1 + stride * i
+        rf_index = kernel + stride * i
+        if rf_index > img_size and stride > 1:
+            rf_index = img_size
+            ri_index = img_size - kernel + 1
+
+        indices[i, 0] = ri_index
+        indices[i, 1] = rf_index
+
+    patches_idx.append(indices)
+
+    return patches_idx
+
+
+def num_patches_for_each_pixel(patches_idx, img_size):
+    """
+    Computes how many patches cover each pixel in a 1D image slice.
+
+    Parameters:
+    patches_idx : list of np.array
+        Patch indices computed by `select_idx_for_patches_1d`.
+    img_size : int
+        Size of the image along one dimension.
+
+    Returns:
+    np.array
+        Array indicating how many patches cover each pixel.
+    """
+    patches_for_each_pixel1D = np.zeros(img_size, dtype=int)
+
+    for indices in patches_idx:
+        for start, end in indices:
+            patches_for_each_pixel1D[start - 1 : end] += 1
+
+    return patches_for_each_pixel1D
+
+
 def get_node_tensor_dtypes(model, node):
     # input tensor (input 0)
     i_name = node.input[0]
@@ -123,6 +213,63 @@ def inference_cost_conv(model, node, discount_sparsity):
     return ret


+def inference_cost_convtranspose(model, node, discount_sparsity):
+    # extract info about the ConvTranspose kernel attributes
+    k = get_by_name(node.attribute, "kernel_shape").ints
+    k_h = k[0]
+    k_w = k[1]
+    k_prod = np.prod(k)
+    stride = get_by_name(node.attribute, "strides").ints
+    s_h = stride[0]
+    s_w = stride[1]
+
+    # extract info from tensor shapes and datatypes
+    (i_dtype, w_dtype, o_dtype) = get_node_tensor_dtypes(model, node)
+    (i_shape, w_shape, o_shape) = get_node_tensor_shapes(model, node)
+
+    bsize = i_shape[0]
+    ifm_ch = i_shape[1]
+    ofm_ch = o_shape[1]
+    assert ofm_ch == w_shape[1], "Mismatch in output channels"
+    ifm_pix_total = np.prod(i_shape[2:])
+
+    # The approximate number of MACs (ignoring spatial summation in the output) depends
+    # on the batch size (bsize), the kernel size (k_prod), the number of input channels
+    # (ifm_ch), the number of output channels (ofm_ch) and the input pixels (ifm_pix_total)
+    n_macs = bsize * k_prod * ifm_ch * ofm_ch * ifm_pix_total
+
+    # In addition, if the stride is smaller than the kernel, patches overlap in the
+    # output and extra addition operations are needed to accumulate the overlap
+    if s_h < k_h or s_w < k_w:
+        # Count how many patches cover each output pixel (a count of 1 means no
+        # overlap, so subtract 1 from every position to get the extra additions)
+        result = np.sum(patches_for_each_pixel_2d(o_shape[2], o_shape[3], k_h, k_w, s_h, s_w) - 1)
+        n_macs = n_macs + result
+
+    w_mem = np.prod(w_shape)
+    o_mem = np.prod(o_shape)
+
+    if discount_sparsity:
+        wname = node.input[1]
+        density = get_node_weight_density(model, wname)
+        n_macs *= density
+        w_mem *= density
+
+    idt_name = i_dtype.name
+    wdt_name = w_dtype.name
+    odt_name = o_dtype.name
+
+    mac_op_type_str = "op_mac_%s_%s" % (idt_name, wdt_name)
+    w_mem_type_str = "mem_w_%s" % (wdt_name)
+    o_mem_type_str = "mem_o_%s" % (odt_name)
+
+    # keep in floats to remain compatible with json serialization
+    n_macs, w_mem, o_mem = float(n_macs), float(w_mem), float(o_mem)
+    ret = {mac_op_type_str: n_macs, w_mem_type_str: w_mem, o_mem_type_str: o_mem}
+
+    return ret
+
+
 def inference_cost_matmul(model, node, discount_sparsity):
     # extract info from tensor shapes and datatypes
     (i_dtype, w_dtype, o_dtype) = get_node_tensor_dtypes(model, node)
@@ -241,6 +388,7 @@ def inference_cost(model, discount_sparsity=True, cost_breakdown=False):
         "MatMul": inference_cost_matmul,
         "Gemm": inference_cost_matmul,
         "Upsample": inference_cost_upsample,
+        "ConvTranspose": inference_cost_convtranspose,
     }
     for node in model.graph.node:
         if node.op_type in inference_cost_fxn_map.keys():