Skip to content

Commit 7c90a8d

Browse files
committed
fix: whitelist names of embed models
1 parent d62a16c commit 7c90a8d

File tree

2 files changed

+65
-0
lines changed

2 files changed

+65
-0
lines changed

pkg/sql2code/parser/parser.go

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -780,9 +780,60 @@ func getModelStructCode(data tmplData, importPaths []string, isEmbed bool, jsonN
780780
structCode = strings.ReplaceAll(structCode, `bson:"id" json:"id"`, `bson:"_id" json:"id"`)
781781
}
782782

783+
tableColumnsCode, err := getTableColumnsCode(data, isEmbed)
784+
if err != nil {
785+
return "", nil, err
786+
}
787+
structCode += string(tableColumnsCode)
788+
783789
return structCode, newImportPaths, nil
784790
}
785791

792+
func getTableColumnsCode(data tmplData, isEmbed bool) ([]byte, error) {
793+
if data.DBDriver == DBDriverMongodb {
794+
for _, field := range data.Fields {
795+
if field.Name == "ID" {
796+
field.ColName = "_id"
797+
data.Fields = append(data.Fields, field)
798+
break
799+
}
800+
}
801+
}
802+
if isEmbed {
803+
var fields = []tmplField{
804+
{
805+
ColName: "id",
806+
},
807+
{
808+
ColName: "created_at",
809+
},
810+
{
811+
ColName: "updated_at",
812+
},
813+
{
814+
ColName: "deleted_at",
815+
},
816+
}
817+
for _, field := range data.Fields {
818+
if field.Name == __mysqlModel__ {
819+
continue
820+
}
821+
fields = append(fields, field)
822+
}
823+
data.Fields = fields
824+
}
825+
builder := strings.Builder{}
826+
err := tableColumnsTmpl.Execute(&builder, data)
827+
if err != nil {
828+
return nil, fmt.Errorf("tableColumnsTmpl.Execute error: %v", err)
829+
}
830+
code, err := format.Source([]byte(builder.String()))
831+
if err != nil {
832+
return nil, fmt.Errorf("tableColumnsTmpl format.Source error: %v", err)
833+
}
834+
return code, err
835+
}
836+
786837
func getModelCode(data modelCodes) (string, error) {
787838
builder := strings.Builder{}
788839
err := modelTmpl.Execute(&builder, data)

pkg/sql2code/parser/template.go

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,16 @@ func (m *{{.TableName}}) TableName() string {
2424
return "{{.RawTableName}}"
2525
}
2626
{{end}}
27+
`
28+
29+
tableColumnsTmpl *template.Template
30+
tableColumnsTmplRaw = `
31+
// {{.TableName}}ColumnNames Whitelist for custom query fields to prevent sql injection attacks
32+
var {{.TableName}}ColumnNames = map[string]bool{
33+
{{- range .Fields}}
34+
"{{.ColName}}": true,
35+
{{- end}}
36+
}
2737
`
2838

2939
modelTmpl *template.Template
@@ -730,6 +740,10 @@ func initTemplate() {
730740
if err != nil {
731741
errSum = errors.Wrap(err, "modelStructTmplRaw")
732742
}
743+
tableColumnsTmpl, err = template.New("tableColumns").Parse(tableColumnsTmplRaw)
744+
if err != nil {
745+
errSum = errors.Wrap(err, "tableColumnsTmplRaw")
746+
}
733747
modelTmpl, err = template.New("goFile").Parse(modelTmplRaw)
734748
if err != nil {
735749
errSum = errors.Wrap(errSum, "modelTmplRaw:"+err.Error())

0 commit comments

Comments
 (0)