@@ -71,37 +71,36 @@ void translateParam(
71
71
// Check if the param has already been converted to halide components.
72
72
if (params->find (p.ident ().name ()) != params->end ()) {
73
73
return ;
74
- } else {
75
- lang::TensorType type = p.tensorType ();
76
- int dimensions = (int )type.dims ().size ();
77
- ImageParam imageParam (
78
- translateScalarType (type.scalarType ()), dimensions, p.ident ().name ());
79
- inputs->push_back (imageParam);
80
- vector<Expr> dims;
81
- for (auto d_ : type.dims ()) {
82
- if (d_->kind () == lang::TK_IDENT) {
83
- auto d = lang::Ident (d_);
84
- auto it = params->find (d.name ());
85
- Parameter p;
86
- if (it != params->end ()) {
87
- p = it->second ;
88
- } else {
89
- p = Parameter (Int (32 ), false , 0 , d.name (), true );
90
- (*params)[d.name ()] = p;
91
- }
92
- dims.push_back (Variable::make (Int (32 ), p.name (), p));
74
+ }
75
+ lang::TensorType type = p.tensorType ();
76
+ int dimensions = (int )type.dims ().size ();
77
+ ImageParam imageParam (
78
+ translateScalarType (type.scalarType ()), dimensions, p.ident ().name ());
79
+ inputs->push_back (imageParam);
80
+ vector<Expr> dims;
81
+ for (auto d_ : type.dims ()) {
82
+ if (d_->kind () == lang::TK_IDENT) {
83
+ auto d = lang::Ident (d_);
84
+ auto it = params->find (d.name ());
85
+ Parameter p;
86
+ if (it != params->end ()) {
87
+ p = it->second ;
93
88
} else {
94
- CHECK (d_->kind () == lang::TK_CONST);
95
- int32_t value = lang::Const (d_).value ();
96
- dims.push_back (Expr (value));
89
+ p = Parameter (Int (32 ), false , 0 , d.name (), true );
90
+ (*params)[d.name ()] = p;
97
91
}
92
+ dims.push_back (Variable::make (Int (32 ), p.name (), p));
93
+ } else {
94
+ CHECK (d_->kind () == lang::TK_CONST);
95
+ int32_t value = lang::Const (d_).value ();
96
+ dims.push_back (Expr (value));
98
97
}
98
+ }
99
99
100
- for (int i = 0 ; i < imageParam.dimensions (); i++) {
101
- imageParam.dim (i).set_bounds (0 , dims[i]);
102
- }
103
- (*params)[imageParam.name ()] = imageParam.parameter ();
100
+ for (int i = 0 ; i < imageParam.dimensions (); i++) {
101
+ imageParam.dim (i).set_bounds (0 , dims[i]);
104
102
}
103
+ (*params)[imageParam.name ()] = imageParam.parameter ();
105
104
}
106
105
107
106
void translateOutput (
0 commit comments