updated inference_cost.py in order to include support for ConvTranspose #177

Open · wants to merge 1 commit into base: main
148 changes: 148 additions & 0 deletions src/qonnx/analysis/inference_cost.py
@@ -31,6 +31,96 @@
from qonnx.util.basic import get_by_name


def patches_for_each_pixel_2d(image_h, image_w, kernel_h, kernel_w, stride_h, stride_w):
"""
Computes the number of patches covering each pixel in a 2D image.

Parameters:
image_h, image_w : int
Height and width of the image.
kernel_h, kernel_w : int
Height and width of the kernel.
    stride_h, stride_w : int
        Vertical and horizontal stride.

Returns:
np.array
Matrix indicating the number of patches covering each pixel.
"""

patches_idx_w = select_idx_for_patches_1d(image_w, kernel_w, stride_w)
patches_idx_h = select_idx_for_patches_1d(image_h, kernel_h, stride_h)

patches_for_each_pixel_h = num_patches_for_each_pixel(patches_idx_h, image_h)
patches_for_each_pixel_w = num_patches_for_each_pixel(patches_idx_w, image_w)

    # the 2D coverage is separable: it is the outer product of the 1D coverages
    patches_for_each_pixel = np.outer(patches_for_each_pixel_h, patches_for_each_pixel_w)

return patches_for_each_pixel
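
# Illustrative sanity check (not part of the patch): for a 4x4 image, a 3x3
# kernel and stride 1, the 1D coverage along each axis is [1, 2, 2, 1], so the
# 2D coverage is its outer product:
#
#   patches_for_each_pixel_2d(4, 4, 3, 3, 1, 1)
#   # -> array([[1, 2, 2, 1],
#   #           [2, 4, 4, 2],
#   #           [2, 4, 4, 2],
#   #           [1, 2, 2, 1]])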


def select_idx_for_patches_1d(img_size, kernel, stride):
"""
Computes the start and end indices for patches in 1D.

Parameters:
img_size : int
Size of the image along one dimension.
kernel : int
Size of the kernel.
stride : int
Size of the stride.

Returns:
    list of np.array
        List containing a single (num_patches, 2) array of 1-based
        (start, end) indices, one row per patch.
"""
patches_idx = []

    # number of kernel positions along this axis (the final one may be clamped)
    num_patches = int(np.ceil((img_size - kernel) / stride + 1))
indices = np.zeros((num_patches, 2), dtype=int)

for i in range(num_patches):
        # 1-based (start, end) position of patch i
        ri_index = 1 + stride * i
        rf_index = kernel + stride * i
        # clamp the final patch so it stays inside the image
        if rf_index > img_size and stride > 1:
            rf_index = img_size
            ri_index = img_size - kernel + 1

indices[i, 0] = ri_index
indices[i, 1] = rf_index

patches_idx.append(indices)

return patches_idx
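
# Illustrative sanity check (not part of the patch): with img_size=4, kernel=3
# and stride=1 there are ceil((4 - 3) / 1 + 1) = 2 patches, returned as 1-based
# (start, end) rows:
#
#   select_idx_for_patches_1d(4, 3, 1)
#   # -> [array([[1, 3],
#   #            [2, 4]])]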


def num_patches_for_each_pixel(patches_idx, img_size):
"""
Computes how many patches cover each pixel in a 1D image slice.

Parameters:
patches_idx : list of np.array
Patch indices computed from `select_idx_for_patches_1d`.
img_size : int
Size of the image along one dimension.

Returns:
np.array
Array indicating how many patches cover each pixel.
"""
    patches_for_each_pixel_1d = np.zeros(img_size, dtype=int)

    for indices in patches_idx:
        for start, end in indices:
            # convert the 1-based inclusive (start, end) pair to a 0-based slice
            patches_for_each_pixel_1d[start - 1 : end] += 1

    return patches_for_each_pixel_1d
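
# Illustrative sanity check (not part of the patch): the two patches from the
# example above cover pixels 1-3 and 2-4 of a 4-pixel axis, so the coverage is:
#
#   num_patches_for_each_pixel(select_idx_for_patches_1d(4, 3, 1), 4)
#   # -> array([1, 2, 2, 1])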


def get_node_tensor_dtypes(model, node):
# input tensor (input 0)
i_name = node.input[0]
@@ -123,6 +213,63 @@ def inference_cost_conv(model, node, discount_sparsity):
return ret


def inference_cost_convtranspose(model, node, discount_sparsity):
    # extract info about the transposed conv kernel attributes
k = get_by_name(node.attribute, "kernel_shape").ints
k_h = k[0]
k_w = k[1]
k_prod = np.prod(k)
stride = get_by_name(node.attribute, "strides").ints
s_h = stride[0]
s_w = stride[1]

# extract info from tensor shapes and datatypes
(i_dtype, w_dtype, o_dtype) = get_node_tensor_dtypes(model, node)
(i_shape, w_shape, o_shape) = get_node_tensor_shapes(model, node)

bsize = i_shape[0]
ifm_ch = i_shape[1]
ofm_ch = o_shape[1]
assert ofm_ch == w_shape[1], "Mismatch in output channels"
ifm_pix_total = np.prod(i_shape[2:])

    # Approximate MAC count: each of the bsize * ifm_pix_total input pixels in
    # each of the ifm_ch input channels is multiplied by the full k_prod kernel
    # for each of the ofm_ch output channels. This ignores the spatial summation
    # in the output, which is accounted for below.
n_macs = bsize * k_prod * ifm_ch * ofm_ch * ifm_pix_total

    # If the stride is smaller than the kernel, the kernel contributions overlap
    # in the output, and the overlapping values must be summed with extra additions.
    if s_h < k_h or s_w < k_w:
        # Count how many patches cover each output pixel; a coverage of 1 means no
        # overlap, so subtract 1 per pixel to get the number of extra additions.
result = np.sum(patches_for_each_pixel_2d(o_shape[2], o_shape[3], k_h, k_w, s_h, s_w) - 1)
n_macs = n_macs + result

w_mem = np.prod(w_shape)
o_mem = np.prod(o_shape)

if discount_sparsity:
wname = node.input[1]
density = get_node_weight_density(model, wname)
n_macs *= density
w_mem *= density

idt_name = i_dtype.name
wdt_name = w_dtype.name
odt_name = o_dtype.name

mac_op_type_str = "op_mac_%s_%s" % (idt_name, wdt_name)
w_mem_type_str = "mem_w_%s" % (wdt_name)
o_mem_type_str = "mem_o_%s" % (odt_name)

# keep in floats to remain compatible with json serialization
n_macs, w_mem, o_mem = float(n_macs), float(w_mem), float(o_mem)
ret = {mac_op_type_str: n_macs, w_mem_type_str: w_mem, o_mem_type_str: o_mem}

return ret
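
# Illustrative walk-through (not part of the patch; all shapes are made up):
# a ConvTranspose with a 1x3x4x4 input, a 3x3 kernel, stride 1 and 8 output
# channels yields a 1x8x6x6 output, so
#   n_macs = 1 * 9 * 3 * 8 * 16 = 3456.
# Since the stride (1) is smaller than the kernel (3), the overlap term is
# added: the 3x3 kernel covers the 6x6 output with a total coverage of 144,
# i.e. 144 - 36 = 108 extra additions, giving n_macs = 3456 + 108 = 3564.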


def inference_cost_matmul(model, node, discount_sparsity):
# extract info from tensor shapes and datatypes
(i_dtype, w_dtype, o_dtype) = get_node_tensor_dtypes(model, node)
@@ -241,6 +388,7 @@ def inference_cost(model, discount_sparsity=True, cost_breakdown=False):
"MatMul": inference_cost_matmul,
"Gemm": inference_cost_matmul,
"Upsample": inference_cost_upsample,
"ConvTranspose": inference_cost_convtranspose,
}
for node in model.graph.node:
if node.op_type in inference_cost_fxn_map.keys():
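As a hypothetical usage sketch (the model path and the datatype names in the comments are illustrative, not taken from the patch), the new handler is picked up by the existing dispatch without any extra steps:

    from qonnx.core.modelwrapper import ModelWrapper
    from qonnx.analysis.inference_cost import inference_cost

    model = ModelWrapper("deconv_model.onnx")
    costs = inference_cost(model, discount_sparsity=True)
    # ConvTranspose nodes now contribute entries such as
    # "op_mac_INT8_INT8", "mem_w_INT8" and "mem_o_INT32" to the result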