diff --git a/mysql/resource_user.go b/mysql/resource_user.go index 725e20ae..bf808c1d 100644 --- a/mysql/resource_user.go +++ b/mysql/resource_user.go @@ -160,38 +160,35 @@ func CreateUser(ctx context.Context, d *schema.ResourceData, meta interface{}) d authStm = " IDENTIFIED WITH " + auth } } + + var hashed string if v, ok := d.GetOk("auth_string_hashed"); ok { - hashed := v.(string) + hashed = v.(string) if hashed != "" { if authStm == "" { return diag.Errorf("auth_string_hashed is not supported for auth plugin %s", auth) } - authStm = fmt.Sprintf("%s AS '%s'", authStm, hashed) + authStm = fmt.Sprintf("%s AS ?", authStm) } } var stmtSQL string + var args []interface{} if createObj == "AADUSER" { var aadIdentity = d.Get("aad_identity").(*schema.Set).List()[0].(map[string]interface{}) - if aadIdentity["type"].(string) == "service_principal" { // CREATE AADUSER 'mysqlProtocolLoginName"@"mysqlHostRestriction' IDENTIFIED BY 'identityId' - stmtSQL = fmt.Sprintf("CREATE AADUSER '%s'@'%s' IDENTIFIED BY '%s'", - d.Get("user").(string), - d.Get("host").(string), - aadIdentity["identity"].(string)) + stmtSQL = "CREATE AADUSER ?@? IDENTIFIED BY ?" + args = []interface{}{d.Get("user").(string), d.Get("host").(string), aadIdentity["identity"].(string)} } else { // CREATE AADUSER 'identityName"@"mysqlHostRestriction' AS 'mysqlProtocolLoginName' - stmtSQL = fmt.Sprintf("CREATE AADUSER '%s'@'%s' AS '%s'", - aadIdentity["identity"].(string), - d.Get("host").(string), - d.Get("user").(string)) + stmtSQL = "CREATE AADUSER ?@? AS ?" + args = []interface{}{aadIdentity["identity"].(string), d.Get("host").(string), d.Get("user").(string)} } } else { - stmtSQL = fmt.Sprintf("CREATE USER '%s'@'%s'", - d.Get("user").(string), - d.Get("host").(string)) + stmtSQL = "CREATE USER ?@?" + args = []interface{}{d.Get("user").(string), d.Get("host").(string)} } var password string @@ -206,47 +203,45 @@ func CreateUser(ctx context.Context, d *schema.ResourceData, meta interface{}) d } if authStm != "" { - stmtSQL = stmtSQL + authStm + stmtSQL += authStm + if hashed != "" { + args = append(args, hashed) + } if password != "" { - stmtSQL = stmtSQL + fmt.Sprintf(" BY '%s'", password) + stmtSQL += " BY ?" + args = append(args, password) } } else if password != "" { - stmtSQL = stmtSQL + fmt.Sprintf(" IDENTIFIED BY '%s'", password) + stmtSQL += " IDENTIFIED BY ?" + args = append(args, password) } requiredVersion, _ := version.NewVersion("5.7.0") - - var updateStmtSql = "" + var updateStmtSql string + var updateArgs []interface{} if getVersionFromMeta(ctx, meta).GreaterThan(requiredVersion) && d.Get("tls_option").(string) != "" { if createObj == "AADUSER" { - updateStmtSql = fmt.Sprintf("ALTER USER '%s'@'%s' REQUIRE %s", - d.Get("user").(string), - d.Get("host").(string), - d.Get("tls_option").(string)) + updateStmtSql = "ALTER USER ?@? REQUIRE " + d.Get("tls_option").(string) + updateArgs = []interface{}{d.Get("user").(string), d.Get("host").(string)} } else { - stmtSQL += fmt.Sprintf(" REQUIRE %s", d.Get("tls_option").(string)) + stmtSQL += " REQUIRE " + d.Get("tls_option").(string) } } - retainPassword := d.Get("retain_old_password").(bool) - if retainPassword { - err := checkRetainCurrentPasswordSupport(ctx, meta) - if err != nil { - return diag.Errorf("cannot use retain_current_password: %v", err) + // Redact sensitive values in args for logging + redactedArgs := make([]interface{}, len(args)) + for i, arg := range args { + if (password != "" && arg == password) || (hashed != "" && arg == hashed) { + redactedArgs[i] = "" + } else { + redactedArgs[i] = arg } } - discardOldPassword := d.Get("discard_old_password").(bool) - if discardOldPassword { - err := checkDiscardOldPasswordSupport(ctx, meta) - if err != nil { - return diag.Errorf("cannot use discard_old_password: %v", err) - } - } + log.Println("[DEBUG] Executing statement:", stmtSQL, "args:", redactedArgs) - log.Println("[DEBUG] Executing statement:", stmtSQL) - _, err = db.ExecContext(ctx, stmtSQL) + _, err = db.ExecContext(ctx, stmtSQL, args...) if err != nil { return diag.Errorf("failed executing SQL: %v", err) } @@ -255,8 +250,8 @@ func CreateUser(ctx context.Context, d *schema.ResourceData, meta interface{}) d d.SetId(user) if updateStmtSql != "" { - log.Println("[DEBUG] Executing statement:", updateStmtSql) - _, err = db.ExecContext(ctx, updateStmtSql) + log.Println("[DEBUG] Executing statement:", updateStmtSql, "args:", updateArgs) + _, err = db.ExecContext(ctx, updateStmtSql, updateArgs...) if err != nil { d.Set("tls_option", "") return diag.Errorf("failed executing SQL: %v", err)