A shared library providing with greedy training algorithms for SNNs.
- C90 standard
- Shared library
- AdvanTest T5830 platform (linux)
#define TEST_ON_T5830 1
(ATLio.h @line 8)#define TRAIN_DATA_LOC "path/to/train/data.txt"
(DataLoader.h @line 9)#define TEST_DATA_LOC "path/to/test/data.txt"
(DataLoader.h @line 10)#define RESULTS_LOC "path/to/results/dir/
需以'/'
结尾 (Network.h @line 11)
-
void GetSetPulseConfig(void* v_bl, void* v_sl, void* v_wl, void* pulse_width)
- @desc Return SET pulse config. Voltage unit: mV, width unit: ns.
- @param
v_bl
<@ATL_typeINTEGER[1]
>: BL voltage. - @param
v_sl
<@ATL_typeINTEGER[1]
>: SL voltage. - @param
v_wl
<@ATL_typeINTEGER[1]
>: WL voltage. - @param
pulse_width
<@ATL_typeINTEGER[1]
>: pulse width.
-
void GetResetPulseConfig(void* v_bl, void* v_sl, void* v_wl, void* pulse_width)
- @desc Return RESET pulse config. Voltage unit: mV, width unit: ns.
- @param
v_bl
<@ATL_typeINTEGER[1]
>: BL voltage. - @param
v_sl
<@ATL_typeINTEGER[1]
>: SL voltage. - @param
v_wl
<@ATL_typeINTEGER[1]
>: WL voltage. - @param
pulse_width
<@ATL_typeINTEGER[1]
>: pulse width.
-
void GetReadPulseConfig(void* v_bl, void* v_sl, void* v_wl, void* pulse_width)
- @desc Return READ pulse config. Voltage unit: mV, width unit: ns.
- @param
v_bl
<@ATL_typeINTEGER[1]
>: BL voltage. - @param
v_sl
<@ATL_typeINTEGER[1]
>: SL voltage. - @param
v_wl
<@ATL_typeINTEGER[1]
>: WL voltage. - @param
pulse_width
<@ATL_typeINTEGER[1]
>: pulse width.
-
void StartTrain()
- @desc 初始化, 准备开始训练.
-
void GetTrainInstruction(void* end_of_train, void* operation_type, void* bl_enable, void* sl_enable)
- @desc 获取前向操作指令, 操作类型可能为
READ
,RESET
或EMPTY
. - @param
end_of_train
<@ATL_typeINTEGER[1]
>: 训练结束则end_of_train=1
, 否则为0
. - @param
operation_type
<@ATL_typeINTEGER[1]
>: 操作类型.SET = 0
: 此函数不会返回SET
操作.RESET = 1
: 选中若干行SL以及一行BL进行RESET
操作.- 选中的
SL
信息返回至sl_enable[128]
数组, 选中的置为1
, 未选中的置为0
. - 选中的
BL
返回至bl_enable
整数型变量, 范围0~7.
- 选中的
READ = 2
: 选中选中若干行SL进行READ
操作 (SL
加0.15V读电压, 从BL
端口读电流),BL
全部打开, 得到8个BL
的累加电流 (nA), 暂存.- 选中的
SL
信息返回至sl_enable[128]
数组, 选中的置为1
, 未选中的置为0
. BL
硬编码为全部打开,bl_enable
不返回有意义的值.
- 选中的
EMPTY = 3
: 空操作, 当且仅当训练结束时会返回空操作.
- @param
bl_enable
<@ATL_typeINTEGER[1]
>: 当且仅当operation_type = RESET
时, 此变量返回需要进行RESET
的BL
索引, 范围0~7. - @param
sl_enable
<@ATL_typeINTEGER[1] (128)
>: 长度为128的数组, 值非0
即1
, 表示对应SL
是否打开.
- @desc 获取前向操作指令, 操作类型可能为
-
void GetTrainFeedbackInstruction(void* bl_currents, void* operation_type, void* bl_enable, void* sl_enable)
- @desc 将前向指令
READ
操作读取到的各BL
电流反馈给网络, 并获取反向操作指令, 操作类型可能为SET
或EMPTY
. - @param
bl_currents
<@ATL_typeINTEGER[1] (8)
>: 长度为8的数组, 将前向READ
操作读取到的电流值存储在此变量, 供网络模型读取. 如果前向操作指令为RESET
, 则不关心此变量的具体值. - @param
operation_type
<@ATL_typeINTEGER[1]
>: 操作类型.SET = 0
: 选中若干行SL以及一行BL进行SET
操作.- 选中的
SL
信息返回至sl_enable[128]
数组, 选中的置为1
, 未选中的置为0
. - 选中的
BL
返回至bl_enable
整数型变量, 范围0~7.
- 选中的
RESET = 1
: 此函数不会返回RESET
操作.READ = 2
: 此函数不会返回READ
操作.EMPTY = 3
: 空操作, 此函数返回空操作不代表训练结束.
- @desc 将前向指令
-
void StartTest()
- @desc 训练结束后调用, 准备开始inference.
-
void GetTestInstruction(void* operation_type, void* sl_enable)
- @desc 获取前向操作指令, 操作类型只可能为
READ
. - @param
sl_enable
<@ATL_typeINTEGER[1] (128)
>: 长度为128的数组, 值非0
即1
, 表示对应SL
是否打开.
- @desc 获取前向操作指令, 操作类型只可能为
-
void GetTestFeedbackInstruction(void* bl_currents, void* end_of_test)
- @desc 将前向读取的电流反馈给网络, 并判断inference是否结束.
- @param
bl_currents
<@ATL_typeINTEGER[1] (8)
>: 长度为8的数组, 将前向READ
操作读取到的电流值存储在此变量, 供网络模型读取. - @param
end_of_test
<@ATL_typeINTEGER[1]
>: 推理结束则end_of_test=1
, 否则为0
.
-
void SaveArray(void* bl0_currents, void* bl1_currents, void* bl2_currents, void* bl3_currents, void* bl4_currents, void* bl5_currents, void* bl6_currents, void* bl7_currents)
- @desc 保存当前阵列各cell状态.
- @param
blx_currents
<@ATL_typeINTEGER[1] (128)
>: 长度为128的数组, 第x行BL的各cell电流值(nA).
-
void Save()
- @desc 保存模型到文件.
-
void EvaluateScore()
- @desc 推理结束后调用, 计算准确率并保存至文件.
-
标准FORMING, 未提供接口: 300nA - 4000nA
-
获取脉冲参数配置:
LoadSetPulseConfig()
;LoadResetPulseConfig()
;LoadReadPulseConfig()
-
开始训练:
StartTrain()
-
while True:
GetTrainInstruction()
- 硬件操作阵列
if end_of_train: break
GetTrainFeedbackInstruction()
- 硬件操作阵列
-
保存阵列状态:
SaveArray()
; 保存网络模型:Save()
-
开始推理:
StartTest()
-
while True:
GetTestInstruction()
- 硬件操作阵列
GetTestFeedbackInstruction()
- 硬件操作阵列
if end_of_test: break
-
保存阵列状态:
SaveArray()
; 保存网络模型:Save()
-
计算准确率:
EvaluateScore()
-
v0.2.0
- Refactor APIs.
- Adapt to ATL interfaces.
-
v0.1.0
- New feature:
void InitNetwork()
. - Add utility functions.
- New feature:
-
v0.0.1
- Build program skeleton.
- New feature:
void LoadMNIST()
, loading MNIST image data into arrays from hard-coded data file path.