|
| 1 | +## 一、概念简介 |
| 2 | +深度学习编译器是一种专门为深度学习模型优化和部署而设计的工具,用于提高模型的计算效率、降低内存占用、加速训练推理过程。其功能是将高层次的深度学习模型转换为低层次的、高效的、底层硬件可执行的代码。简单来说,深度学习编译器在深度学习框架和底层硬件之间充当了“翻译”的角色,能够将用户定义的神经网络模型描述转化为底层硬件能够理解和执行的指令。编译器在实现这种转换的过程中,应用了一系列优化技术,以提高模型在各种硬件平台上(如 CPU、GPU)的执行效率。 |
| 3 | +深度学习编译器的主要功能包括: |
| 4 | +- **模型转换**:将高层次的深度学习模型转换为适合目标硬件的中间表示(IR)。 |
| 5 | +- **优化**:应用各种编译优化技术,如图优化、内存优化、算子融合等,以提高执行效率。 |
| 6 | +- **代码生成**:生成适合目标硬件的可执行代码。 |
| 7 | + |
| 8 | +## 二、背景与动机 |
| 9 | +深度学习模型的训练和推理过程涉及大量的计算,对硬件性能要求很高。飞桨框架虽然提供了高级的编程接口和丰富的算子库,但在执行效率和模型部署方面还有很大的优化空间。使用深度学习编译器的主要动机包括: |
| 10 | +#### 1. 优化性能与资源利用率 |
| 11 | +深度学习模型往往需要处理大量的数据和复杂的计算,直接在高层次框架上执行可能无法充分利用底层硬件的能力。深度学习编译器能够深入硬件特性,应用多种优化技术,提高计算效率,降低延迟。并且通过优化模型的计算图和内存使用,深度学习编译器也能够明显降低模型的内存和 IO 资源的消耗,进而提高计算性能。 |
| 12 | +#### 2. 硬件多样性支持 |
| 13 | +不同的硬件平台有不同的特性和优化需求。在现有机制下,新的异构硬件设备接入深度学习框架需要手工实现几百个算子对应的硬件 Kernel 代码,开发的工作量非常大。如果使用深度学习编译器,理论上仅需实现新硬件 IR 层面的对接,以及相应的硬件 IR 优化策略就能完成与深度学习框架的对接,相比于实现几百个硬件 Kernel,开发的工作量会大幅减少。 |
| 14 | +#### 3. 提升开发效率 |
| 15 | +深度学习编译器可以自动化许多优化过程,减少手动调优的工作量。开发者只需关注模型的设计和训练,而不必深入了解底层硬件优化细节,从而提高开发效率。 |
| 16 | + |
| 17 | +## 三、使用示例: |
| 18 | +飞桨框架编译器(CINN)使用时仅需在原先的模型动转静或推理流程下打开编译器相关 FLAGS 即可,无需对模型代码做任何改动。以下是一个使用样例: |
| 19 | + |
| 20 | +示例代码文件:`run_net.py` |
| 21 | +```python |
| 22 | +import paddle |
| 23 | +from paddle import nn |
| 24 | +from paddle.static import InputSpec |
| 25 | + |
| 26 | +# 定义神经网络 |
| 27 | +class RMSNorm(nn.Layer): |
| 28 | + def __init__(self): |
| 29 | + super().__init__() |
| 30 | + paddle.seed(2024) |
| 31 | + self.hidden_size = 768 |
| 32 | + self.weight = paddle.randn([self.hidden_size], dtype="float32") |
| 33 | + self.variance_epsilon = 1e-6 |
| 34 | + |
| 35 | + def forward(self, hidden_states): |
| 36 | + variance = (hidden_states * hidden_states).sum(-1, keepdim=True) / 768 |
| 37 | + hidden_states = ( |
| 38 | + paddle.rsqrt(variance + self.variance_epsilon) * hidden_states |
| 39 | + ) |
| 40 | + return hidden_states * self.weight |
| 41 | + |
| 42 | + |
| 43 | +def run_net(input_data): |
| 44 | + net = RMSNorm() |
| 45 | + |
| 46 | + # 指定输入变量的维度、数据类型等信息,具体接口可参考: |
| 47 | + # https://www.paddlepaddle.org.cn/documentation/docs/zh/guides/jit/basic_usage_cn.html#inputspec |
| 48 | + input_spec = [ |
| 49 | + InputSpec(shape=[1, None, 768], dtype='float32'), |
| 50 | + ] |
| 51 | + net = paddle.jit.to_static( |
| 52 | + net, |
| 53 | + input_spec=input_spec, |
| 54 | + full_graph=True, |
| 55 | + ) |
| 56 | + # 使用 eval 模式 |
| 57 | + net.eval() |
| 58 | + # 执行计算图 |
| 59 | + out = net(input_data) |
| 60 | + return out |
| 61 | + |
| 62 | +# 创建输入数据 |
| 63 | +input_data = paddle.randn([1, 2048, 768], dtype="float32") |
| 64 | +# 运行神经网络 |
| 65 | +out = run_net(input_data) |
| 66 | +print(out) |
| 67 | +``` |
| 68 | + |
| 69 | +脚本执行:`run.sh` |
| 70 | +``` |
| 71 | +# 打开组合算子 |
| 72 | +export FLAGS_prim_enable_dynamic=true && export FLAGS_prim_all=true |
| 73 | +
|
| 74 | +# 打开 CINN 编译器相关 FLAG |
| 75 | +export FLAGS_use_cinn=true |
| 76 | +export FLAGS_cinn_new_group_scheduler=true |
| 77 | +export FLAGS_group_schedule_tiling_first=true |
| 78 | +export FLAGS_cinn_bucket_compile=true |
| 79 | +
|
| 80 | +# 打开 PIR 模式 |
| 81 | +export FLAGS_enable_pir_api=true |
| 82 | +
|
| 83 | +# 是否打印 Program IR 信息 |
| 84 | +export FLAGS_print_ir=false |
| 85 | +
|
| 86 | +python run_net.py |
| 87 | +``` |
| 88 | + |
| 89 | +上述代码示例中我们创建了一个简单的`rms_norm`计算子图,使用飞桨的动转静流程将子图转为静态图并调用编译器 CINN 进行优化和执行。经过性能对比测试,在 A100 GPU 环境中上述子图使用 CINN 可以取得 3 倍左右的性能提升(该性能数据仅供学习参考,在实际应用模型中能够取得的性能提升效果一般会低于该数据)。 |
| 90 | + |
| 91 | +注:由于飞桨的编译器仍然处在快速迭代开发阶段,我们设置了较多 FLAGS 进行分支的选择和调试,因此现阶段在使用 CINN 时需要对如下 FLAGS(`FLAGS_prim_enable_dynamic`、 `FLAGS_cinn_new_group_scheduler`、 `FLAGS_group_schedule_tiling_first`、 `FLAGS_cinn_bucket_compile`、 `FLAGS_enable_pir_api`) 进行手动设置,待后续相关功能完备后这些 FLAGS 会默认开启,无需再手动设置。 |
| 92 | + |
| 93 | +## 四、设计架构 |
| 94 | +<center><img src="https://github.com/PaddlePaddle/docs/blob/develop/docs/guides/images/cinn/cinn_design.png?raw=true" width="800" ></center> |
| 95 | +<br><center>图 1 CINN 整体架构 </center> |
| 96 | + |
| 97 | +飞桨框架编译器(CINN, Compiler Infrastructure for Neural Networks)整体架构如上图所示,大体可以分为三个模块,分别是编译器前端、编译器后端和执行器部分。 |
| 98 | + |
| 99 | +### 1. 编译器前端 |
| 100 | +一般来说编译器前端需要将不同框架和格式的深度学习模型转换为编译器的内部 IR 并进行图级别的优化,CINN 作为飞桨框架原生编译器,可以直接使用飞桨框架提供的模型加载和中间表示(Paddle IR,简称 PIR)组件,因此 CINN 前端的主要功能是基于 PIR 进行图层级别的优化,并对子图进行划分为后端高性能 Kernel 代码生成提供支持。CINN 前端关键的流程可分为三部分: |
| 101 | + |
| 102 | +#### a. 组合算子拆分 |
| 103 | +飞桨框架中将算子划分为基础算子(也称作原子算子,语义上该算子无法更进一步拆分成其他算子。基础算子语义上可以通过重组等价实现组合算子的逻辑)和非基础算子两类大,由于非基础算子数量较多,并且在编译器中较难识别和处理,因此我们使用组合算子拆分的方式将非基础算子拆分为等价的基础算子组合,原始计算图经过组合算子拆分后可以大幅提升性能的可优化空间。 |
| 104 | + |
| 105 | +#### b. 图优化 Pass |
| 106 | +在计算图层级进行 PIR 的 Pass 优化,常见的图优化 Pass 包括:常量折叠、死代码消除(DCE)、公共子表达式消除(CSE)、冗余算子消除、算子计算合并等。 |
| 107 | + |
| 108 | +#### c. 算子融合 |
| 109 | +算子融合是编译器前端非常重要的一个功能,主要是将多个算子打包到一个子图中(对应为一个 FusionOp),交给编译器后端生成一个高效的硬件相关计算 Kernel。 |
| 110 | +算子融合的本质是通过 IO 优化加速访存密集算子,如果我们将两个连续 Kernel 合并为一个 Kernel 调用,我们会减少中间变量的读写开销,因此在访存密集型的 2 个 Op 上,融合可以获取更高的性能。举个例子,如下图: |
| 111 | +<center><img src="https://github.com/PaddlePaddle/docs/blob/develop/docs/guides/images/cinn/op_fusion.png?raw=true" width="200" ></center> |
| 112 | +<br><center>图 2 算子融合示例 </center> |
| 113 | + |
| 114 | +我们有两个算子 Relu 和 Scale,因为两个算子都是 IO 密集型算子(计算复杂度不高)。正常情况下我们需要读取 A 和 B 一次,写 B 和 C 一次。但是对于融合之后的 Kernel(右图)而言,我们只需要读取 A 和写 C 一次,这样我们通过算子融合可以取得更少的访存次数,在 IO 密集算子而言,可以极大提高性能。 |
| 115 | +具体的算子融合策略实现非常复杂,这里不做展开介绍,感兴趣的读者可以阅读相关源码 #cinn_group_cluster_pass。 |
| 116 | + |
| 117 | +### 2. 编译器后端 |
| 118 | +编译器后端主要负责将前端处理后的 IR 转换为目标硬件可执行的代码或硬件描述。主要功能包括基于硬件特性的 IR 优化、高效内存管理和代码生成等。 |
| 119 | + |
| 120 | +#### 2.1. CINN AST IR |
| 121 | +AST IR 打印示例: |
| 122 | +``` |
| 123 | +ScheduleBlock(root) |
| 124 | +{ |
| 125 | + serial for (i, 0, 32) |
| 126 | + { |
| 127 | + serial for (j_0, 0, 64) |
| 128 | + { |
| 129 | + serial for (j_1, 0, 128) |
| 130 | + { |
| 131 | + ScheduleBlock(A) |
| 132 | + { |
| 133 | + vi, vj = axis.bind(i, j_0 * 64 + j_1) // tensor 下标与循环变量的仿射变换 |
| 134 | + A[vi, vj] = X[vi, vj] * 2 |
| 135 | + } |
| 136 | + } |
| 137 | + } |
| 138 | + } |
| 139 | +} |
| 140 | +``` |
| 141 | +CINN AST IR 中包含了以下信息,但集合和映射并不显示使用某种数据结构进行存储。 |
| 142 | + |
| 143 | +  **集合**:语句实例 & 内存单元 **<br>** |
| 144 | +  **映射**:**<br>** |
| 145 | +   访存关系:语句实例 <---> 内存单元 **<br>** |
| 146 | +   依赖关系:语句实例 <---> 语句实例 **<br>** |
| 147 | +   执行顺序:语句实例 -----> 语句实例 **<br>** |
| 148 | + |
| 149 | +  执行顺序 = 语句实例的先后关系 **<br>** |
| 150 | +  语句实例集合范围 = 循环边界 + 循环步长 ------ 循环构成一个带约束的整数空间,即迭代空间,迭代空间决定了语句实例,语句实例充满了迭代空间。 |
| 151 | + |
| 152 | +#### 2.2. 基于 AST IR 的 Schedule |
| 153 | +Schedule 为定义在 CINN AST IR 上的优化策略,常见的 Schedule 包括:LoopAlignment, Tile, Inline, Vectorize, Unroll 等。**<br>** |
| 154 | +以一个组合算子为例模拟可能的 AST 变换过程:**<br>** |
| 155 | + [S1, S2, 1024] ==E=> [S1, S2, 1024] ==R=> [S1, S2] ==E=> [S1, S2] ==B=> [S1, S2, 1024] ==E=> [S1, S2, 1024] |
| 156 | + |
| 157 | +**(1) LowerToAst 得到的结果** |
| 158 | +``` |
| 159 | +// Elemenwise-1 |
| 160 | +serial for (i, 0, S1) |
| 161 | + serial for (j, 0, S2) |
| 162 | + serial for (k, 0, 1024) |
| 163 | + ScheduleBlock(A) |
| 164 | + vi, vj, vk = axis.bind(i, j, k) |
| 165 | + A[vi, vj, vk] = X[vi, vj, vk] * 2 |
| 166 | +// Elemenwise-2 |
| 167 | +serial for (i, 0, S1) |
| 168 | + serial for (j, 0, S2) |
| 169 | + serial for (k, 0, 1024) |
| 170 | + ScheduleBlock(B) |
| 171 | + vi, vj, vk = axis.bind(i, j, k) |
| 172 | + B[vi, vj, vk] = A[vi, vj, vk] + 1 |
| 173 | +// Reduce-1 |
| 174 | +serial for (i, 0, S1) |
| 175 | + serial for (j, 0, S2) |
| 176 | + ScheduleBlock(C__reduce_init) |
| 177 | + vi, vj = axis.bind(i, j) |
| 178 | + C_init[vi, vj] = 0 |
| 179 | +serial for (i, 0, S1) |
| 180 | + serial for (j, 0, S2) |
| 181 | + serial for (k, 0, 1024) // Reduce |
| 182 | + ScheduleBlock(C) |
| 183 | + vi, vj, vk = axis.bind(i, j, k) |
| 184 | + C[vi, vj] = C[vi, vj] + B[vi, vj, vk] |
| 185 | +// Elemenwise-3 |
| 186 | +serial for (i, 0, S1) |
| 187 | + serial for (j, 0, S2) |
| 188 | + ScheduleBlock(D) |
| 189 | + vi, vj = axis.bind(i, j) |
| 190 | + D[vi, vj] = C[vi, vj] * 2 |
| 191 | +// Broadcast-1 |
| 192 | +serial for (i, 0, S1) |
| 193 | + serial for (j, 0, S2) |
| 194 | + serial for (k, 0, 1024) // Broadcast |
| 195 | + ScheduleBlock(E) |
| 196 | + vi, vj, vk = axis.bind(i, j, k) |
| 197 | + E[vi, vj, vk] = D[vi, vj] |
| 198 | +// Elemenwise-4 |
| 199 | +serial for (i, 0, S1) |
| 200 | + serial for (j, 0, S2) |
| 201 | + serial for (k, 0, 1024) |
| 202 | + ScheduleBlock(F) |
| 203 | + vi, vj, vk = axis.bind(i, j, k) |
| 204 | + F[vi, vj, vk] = E[vi, vj, vk] + 1 |
| 205 | +``` |
| 206 | +**(2) 迭代空间对齐** |
| 207 | +``` |
| 208 | +// 所有 ScheduleBlock 的 loop nest 都变为以下 2 种格式中的一种 |
| 209 | +// 1 |
| 210 | +serial for (sp, 0, S1 * S2) // pure_spatial_iter |
| 211 | + serial for (rb, 0, 1024) // impure_spatial_iter |
| 212 | + ScheduleBlock(XXX) |
| 213 | + vsp1, vsp2, vrb = axis.bind(sp / S2, sp % S2, rb) |
| 214 | + XXX = XXXXXX |
| 215 | +// 2 |
| 216 | +serial for (sp, 0, S1 * S2) // pure_spatial_iter |
| 217 | + ScheduleBlock(XXX) |
| 218 | + vsp1, vsp2 = axis.bind(sp / S2, sp % S2) |
| 219 | + XXX = XXXXXX |
| 220 | +``` |
| 221 | +**(3) Tile: 对所有 ScheduleBlock 的 loop nest 做相同的 Tile** |
| 222 | +``` |
| 223 | +// pure_spatial 轴 Tile 为:-1 * 16 * 64 Tile size 可为参数传入 |
| 224 | +serial for (sp1, 0, S1 * S2 / 1024) |
| 225 | + serial for (sp2, 0, 16) |
| 226 | + serial for (sp3, 0, 64) // S1 * S2 / 16 / 64, predicate: sp1 * 1024 + sp2 * 16 + sp3 < S1 * S2 |
| 227 | + XXXXXX |
| 228 | +// impure_spatial_iter 轴 Tile 为 32 |
| 229 | +serial for (sp1, 0, S1 * S2 / 1024) |
| 230 | + serial for (sp2, 0, 16) |
| 231 | + serial for (sp3, 0, 64) |
| 232 | + serial for (rb1, 0, 32) |
| 233 | + serial for (rb2, 0, 32) |
| 234 | + ScheduleBlock(XXX) |
| 235 | + predicate = sp1 * 1024 + sp2 * 16 + sp3 < S1 * S2 |
| 236 | + vsp1 = axis.bind((sp1 * 1024 + sp2 * 16 + sp3) / S2) |
| 237 | + vsp2 = axis.bind((sp1 * 1024 + sp2 * 16 + sp3) % S2) |
| 238 | + vrb = axis.bind(rb1 * 32 + rb2) |
| 239 | + XXX = XXXXX |
| 240 | +``` |
| 241 | +**(4) ComputeInline** |
| 242 | +``` |
| 243 | +// 例如 ScheduleBlock(A) inline 到 ScheduleBlock(B) |
| 244 | +serial for (sp1, 0, S1 * S2 / 1024) |
| 245 | + serial for (sp2, 0, 16) |
| 246 | + serial for (sp3, 0, 64) |
| 247 | + serial for (rb1, 0, 32) |
| 248 | + serial for (rb2, 0, 32) |
| 249 | + ScheduleBlock(A) |
| 250 | + predicate = sp1 * 1024 + sp2 * 16 + sp3 < S1 * S2 |
| 251 | + vsp1 = axis.bind((sp1 * 1024 + sp2 * 16 + sp3) / S2) |
| 252 | + vsp2 = axis.bind((sp1 * 1024 + sp2 * 16 + sp3) % S2) |
| 253 | + vrb = axis.bind(rb1 * 32 + rb2) |
| 254 | + B[vsp1, vsp2, vrb] = (X[vsp1, vsp2, vrb] * 2) + 1 |
| 255 | +``` |
| 256 | +**(5) Reduce 优化: two step reduce & 绑定部分 reduce 轴到 cuda** |
| 257 | +``` |
| 258 | +// 为了简洁,此处省略 reduce_init Block 和 predicate |
| 259 | +serial for (sp1, 0, S1 * S2 / 1024) |
| 260 | + serial for (sp2, 0, 16) |
| 261 | + serial for (sp3, 0, 64) |
| 262 | + CudaBind[ThreadIdx.x] for (rb1, 0, 32) |
| 263 | + serial for (rb2, 0, 32) |
| 264 | + ScheduleBlock(C_rf) |
| 265 | + vsp1 = axis.bind((sp1 * 1024 + sp2 * 16 + sp3) / S2) |
| 266 | + vsp2 = axis.bind((sp1 * 1024 + sp2 * 16 + sp3) % S2) |
| 267 | + vrb1 = axis.bind(rb1) |
| 268 | + vrb2 = axis.bind(rb2) |
| 269 | + C_rf[vsp1, vsp2, vrb1] = C_rf[vsp1, vsp2, vrb1] + B[vsp1, vsp2, vrb1 * 32 + vrb2] |
| 270 | + ScheduleBlock(C) |
| 271 | + vsp1 = axis.bind((sp1 * 1024 + sp2 * 16 + sp3) / S2) |
| 272 | + vsp2 = axis.bind((sp1 * 1024 + sp2 * 16 + sp3) % S2) |
| 273 | + vrb1 = axis.bind(rb1) |
| 274 | + C[vsp1, vsp2] = C[vsp1, vsp2] + C_rf[vsp1, vsp2, vrb1] |
| 275 | +``` |
| 276 | +**(6) 循环融合: ComputeAt && SimpleComputeAt,融合外层循环乘积相同的循环,并且保证不破坏图级别依赖(规则负责)和元素级别依赖(原语负责)** |
| 277 | +``` |
| 278 | +serial for (sp1, 0, S1 * S2 / 1024) |
| 279 | + serial for (sp2, 0, 16) |
| 280 | + serial for (sp3, 0, 64) |
| 281 | + ScheduleBlock(D) |
| 282 | + vsp1 = axis.bind((sp1 * 1024 + sp2 * 16 + sp3) / S2) |
| 283 | + vsp2 = axis.bind((sp1 * 1024 + sp2 * 16 + sp3) % S2) |
| 284 | + D[vsp1, vsp2] = C[vsp1, vsp2] * 2 |
| 285 | + serial for (rb1, 0, 32) |
| 286 | + serial for (rb2, 0, 32) |
| 287 | + ScheduleBlock(E) |
| 288 | + vsp1 = axis.bind((sp1 * 1024 + sp2 * 16 + sp3) / S2) |
| 289 | + vsp2 = axis.bind((sp1 * 1024 + sp2 * 16 + sp3) % S2) |
| 290 | + vrb = axis.bind(rb1 * 32 + rb2) |
| 291 | + E[vsp1, vsp2, vrb] = D[vsp1, vsp2] |
| 292 | + ScheduleBlock(F) |
| 293 | + vsp1 = axis.bind((sp1 * 1024 + sp2 * 16 + sp3) / S2) |
| 294 | + vsp2 = axis.bind((sp1 * 1024 + sp2 * 16 + sp3) % S2) |
| 295 | + vrb = axis.bind(rb1 * 32 + rb2) |
| 296 | + F[vsp1, vsp2, vrb] = E[vsp1, vsp2, vrb] + 1 |
| 297 | +``` |
| 298 | +**(7) Bind Cuda 轴:在第二步中,所有 ScheduleBlock 对应的循环要 bind 到同一 Cuda 轴** |
| 299 | +``` |
| 300 | +serial for (sp1, 0, S1 * S2 / 1024) |
| 301 | + CudaBind[BlockIdx.x] for (sp2, 0, 16) |
| 302 | + CudaBind[ThreadIdx.y] for (sp3, 0, 64) |
| 303 | + CudaBind[ThreadIdx.x] for (rb1, 0, 32) |
| 304 | + serial for (rb2, 0, 32) |
| 305 | + ScheduleBlock(XXX) |
| 306 | +``` |
| 307 | + |
| 308 | +#### 2.3. Kernel 代码生成与编译 |
| 309 | + |
| 310 | +Codegen 在 CINN IR AST 上做前序遍历,打印出对应硬件的指令,并通过硬件相对应的编译器(如 llvm、nvcc 等)进行编译得到可运行的函数指针,该指针会被封装到 `JitKernelOp`` 中用于后续执行器的解析执行。 |
| 311 | + |
| 312 | +a. 以函数定义为例子,cuda kernel func 和 x86 kernel func 的不同的是,cuda kernel func 会在函数名前增加 `__global__` |
| 313 | + |
| 314 | +针对 x86 硬件,转义 `ir::_LoweredFunc_` 的代码如下: |
| 315 | +``` |
| 316 | +void CodeGenC::Visit(const ir::_LoweredFunc_ *op) { |
| 317 | + PrintFunctionDeclaration(op); // 前序遍历继续转义函数名、函数参数等 |
| 318 | + str_ += "\n"; |
| 319 | + ... |
| 320 | + ... |
| 321 | +} |
| 322 | +``` |
| 323 | +在 NV GPU 上的转义代码如下: |
| 324 | +``` |
| 325 | +void CodeGenCUDA_Dev::Visit(const ir::_LoweredFunc_ *op) { |
| 326 | + str_ += "__global__\n"; // 和 x86 的不同,增加 __global__ |
| 327 | + PrintFunctionDeclaration(op); // 前序遍历继续转义函数名、函数参数等 |
| 328 | + str_ += "\n"; |
| 329 | + ... |
| 330 | + ... |
| 331 | +} |
| 332 | +``` |
| 333 | +b. 在动态形状场景下,还会 codegen 出 infer shape function, infer shape function 的 CINN IR 会在 Bucket Lowering 中得到,转义过程复用的 x86 硬件的 codegen。infer shape kernel 如下: |
| 334 | +``` |
| 335 | +// infer shape 函数名字的组成:kernel_name + "infer_shape" |
| 336 | +// 函数参数: |
| 337 | +// kernel_args: 指针数组,和 kernel func args 一致 |
| 338 | +// kernel_args_num: kernel_args 的长度 |
| 339 | +// tensor_shape_args: 指针数组,存储输出 tensor 的 shape |
| 340 | +function fn_exp_0_subtract_0_infer_shape (kernel_args, kernel_args_num, tensor_shape_args) |
| 341 | +{ |
| 342 | + int64 S0 = cinn_get_value_in_cuda_kernel_args(kernel_args, 2) |
| 343 | + { |
| 344 | + // CINN IR 暂时不支持数据索引的语法,暂时用函数调用实现,下面 2 条语句等价于 |
| 345 | + // tensor_shape_args[0] = {S0, 256ll}; |
| 346 | + // 即第 0 个出 tensor 的 shape 为{S0, 256ll}; |
| 347 | + infer_shape_set_value(0, 0, S0, tensor_shape_args) |
| 348 | + infer_shape_set_value(0, 1, 256ll, tensor_shape_args) |
| 349 | + } |
| 350 | +} |
| 351 | +``` |
| 352 | + |
| 353 | +### 3. 执行器 |
| 354 | + |
| 355 | +编译器生成的 Kernel 代码需要与深度学习框架执行器完成交互和集成才能最终运行起来,因此需要基于执行器的运行调度接口对编译器生成的 Kernel 进行封装。 |
| 356 | + |
| 357 | +接入执行器后在运行时对于经过编译器处理的子图将执行 CINN 生成的 Kernel, 否则将执行常规的 PHI 算子 Kernel。 |
0 commit comments