Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 32 additions & 5 deletions protoc-gen-openapiv2/internal/genopenapi/template.go
Original file line number Diff line number Diff line change
Expand Up @@ -607,11 +607,38 @@ func renderMessageAsDefinition(msg *descriptor.Message, reg *descriptor.Registry
}

if fieldSchema.Required != nil {
schema.Required = getUniqueFields(schema.Required, fieldSchema.Required)
schema.Required = append(schema.Required, fieldSchema.Required...)
// To avoid populating both the field schema require and message schema require, unset the field schema require.
// See issue #2635.
fieldSchema.Required = nil
// Only hoist required fields to parent if there are no path params inside this field.
if len(subPathParams) == 0 {
schema.Required = getUniqueFields(schema.Required, fieldSchema.Required)
schema.Required = append(schema.Required, fieldSchema.Required...)
// To avoid populating both the field schema require and message schema require, unset the field schema require.
// See issue #2635.
fieldSchema.Required = nil
} else {
// When there are path params, we need to separate field-level required from nested required.
// The field name itself (if required) should be in parent's required, but nested field names
// should stay in the nested schema's required.
fieldName := f.GetName()
if reg.GetUseJSONNamesForFields() {
fieldName = f.GetJsonName()
}
// Check if the field name is in the fieldSchema.Required (it would be if the field is marked REQUIRED)
var nestedRequired []string
fieldIsRequired := false
for _, req := range fieldSchema.Required {
if req == fieldName {
fieldIsRequired = true
} else {
nestedRequired = append(nestedRequired, req)
}
}
// Add the field name to parent's required if the field itself is required
if fieldIsRequired && find(schema.Required, fieldName) == -1 {
schema.Required = append(schema.Required, fieldName)
}
// Keep only the nested required fields in the field schema
fieldSchema.Required = nestedRequired
}
}

if reg.GetUseAllOfForRefs() {
Expand Down
223 changes: 223 additions & 0 deletions protoc-gen-openapiv2/internal/genopenapi/template_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -11437,3 +11437,226 @@ func Test_updateSwaggerObjectFromFieldBehavior(t *testing.T) {
})
}
}

// TestNestedRequiredFieldsNotHoisted tests the bug where nested required fields
// are incorrectly hoisted to the parent body schema's required array when path
// parameters reference nested fields
func TestNestedRequiredFieldsNotHoisted(t *testing.T) {
fieldBehaviorRequired := []annotations.FieldBehavior{annotations.FieldBehavior_REQUIRED}
requiredFieldOptions := new(descriptorpb.FieldOptions)
proto.SetExtension(requiredFieldOptions, annotations.E_FieldBehavior, fieldBehaviorRequired)

// Define the nested message (Foo) with REQUIRED fields
fooDesc := &descriptorpb.DescriptorProto{
Name: proto.String("Foo"),
Field: []*descriptorpb.FieldDescriptorProto{
{
Name: proto.String("name"),
Type: descriptorpb.FieldDescriptorProto_TYPE_STRING.Enum(),
Number: proto.Int32(1),
Options: requiredFieldOptions, // name is REQUIRED
},
{
Name: proto.String("value"),
Type: descriptorpb.FieldDescriptorProto_TYPE_STRING.Enum(),
Number: proto.Int32(2),
Options: requiredFieldOptions, // value is REQUIRED
},
},
}

// Define the request message (UpdateFooRequest)
updateFooReqDesc := &descriptorpb.DescriptorProto{
Name: proto.String("UpdateFooRequest"),
Field: []*descriptorpb.FieldDescriptorProto{
{
Name: proto.String("thing"),
Type: descriptorpb.FieldDescriptorProto_TYPE_MESSAGE.Enum(),
TypeName: proto.String(".test.Foo"),
Number: proto.Int32(1),
Options: requiredFieldOptions, // thing is REQUIRED
},
{
Name: proto.String("update_mask"),
Type: descriptorpb.FieldDescriptorProto_TYPE_STRING.Enum(), // Simplified - normally FieldMask
Number: proto.Int32(2),
},
},
}

fooMsg := &descriptor.Message{
DescriptorProto: fooDesc,
}

updateFooReqMsg := &descriptor.Message{
DescriptorProto: updateFooReqDesc,
}

nameField := &descriptor.Field{
Message: fooMsg,
FieldDescriptorProto: fooMsg.GetField()[0],
}
valueField := &descriptor.Field{
Message: fooMsg,
FieldDescriptorProto: fooMsg.GetField()[1],
}
fooMsg.Fields = []*descriptor.Field{nameField, valueField}

thingField := &descriptor.Field{
Message: updateFooReqMsg,
FieldMessage: fooMsg,
FieldDescriptorProto: updateFooReqMsg.GetField()[0],
}
updateMaskField := &descriptor.Field{
Message: updateFooReqMsg,
FieldDescriptorProto: updateFooReqMsg.GetField()[1],
}
updateFooReqMsg.Fields = []*descriptor.Field{thingField, updateMaskField}

meth := &descriptorpb.MethodDescriptorProto{
Name: proto.String("UpdateFoo"),
InputType: proto.String("UpdateFooRequest"),
OutputType: proto.String("Foo"),
}

svc := &descriptorpb.ServiceDescriptorProto{
Name: proto.String("FooService"),
Method: []*descriptorpb.MethodDescriptorProto{meth},
}

file := descriptor.File{
FileDescriptorProto: &descriptorpb.FileDescriptorProto{
SourceCodeInfo: &descriptorpb.SourceCodeInfo{},
Name: proto.String("foo.proto"),
Package: proto.String("test"),
MessageType: []*descriptorpb.DescriptorProto{fooDesc, updateFooReqDesc},
Service: []*descriptorpb.ServiceDescriptorProto{svc},
Options: &descriptorpb.FileOptions{
GoPackage: proto.String("github.com/example/test;test"),
},
},
GoPkg: descriptor.GoPackage{
Path: "example.com/path/to/test/test.pb",
Name: "test_pb",
},
Messages: []*descriptor.Message{fooMsg, updateFooReqMsg},
Services: []*descriptor.Service{
{
ServiceDescriptorProto: svc,
Methods: []*descriptor.Method{
{
MethodDescriptorProto: meth,
RequestType: updateFooReqMsg,
ResponseType: fooMsg,
Bindings: []*descriptor.Binding{
{
HTTPMethod: "PATCH",
PathTmpl: httprule.Template{
Version: 1,
OpCodes: []int{0, 0},
Template: "/api/v1/{thing.name}",
},
PathParams: []descriptor.Parameter{
{
FieldPath: descriptor.FieldPath([]descriptor.FieldPathComponent{
{
Name: "thing",
},
{
Name: "name",
},
}),
Target: nameField,
},
},
Body: &descriptor.Body{
FieldPath: []descriptor.FieldPathComponent{}, // body: "*"
},
},
},
},
},
},
},
}

reg := descriptor.NewRegistry()
fileCL := crossLinkFixture(&file)
err := reg.Load(reqFromFile(fileCL))
if err != nil {
t.Errorf("reg.Load(%#v) failed with %v; want success", file, err)
return
}

result, err := applyTemplate(param{File: fileCL, reg: reg})
if err != nil {
t.Fatalf("applyTemplate(%#v) failed with %v; want success", file, err)
}

paths := GetPaths(result)
if got, want := len(paths), 1; got != want {
t.Fatalf("Results path length differed, got %d want %d", got, want)
}
if got, want := paths[0], "/api/v1/{thing.name}"; got != want {
t.Fatalf("Wrong results path, got %s want %s", got, want)
}

operation := *result.getPathItemObject("/api/v1/{thing.name}").Patch
if len(operation.Parameters) < 2 {
t.Fatalf("Expected at least 2 parameters, got %d", len(operation.Parameters))
}

if got, want := operation.Parameters[0].Name, "thing.name"; got != want {
t.Fatalf("Wrong parameter name 0, got %s want %s", got, want)
}
if got, want := operation.Parameters[0].In, "path"; got != want {
t.Fatalf("Wrong parameter location 0, got %s want %s", got, want)
}

if got, want := operation.Parameters[1].Name, "body"; got != want {
t.Fatalf("Wrong parameter name 1, got %s want %s", got, want)
}
if got, want := operation.Parameters[1].In, "body"; got != want {
t.Fatalf("Wrong parameter location 1, got %s want %s", got, want)
}

bodySchemaRef := operation.Parameters[1].Schema.schemaCore.Ref
if bodySchemaRef == "" {
t.Fatal("Body schema reference is empty")
}

defName := strings.TrimPrefix(bodySchemaRef, "#/definitions/")
definition, found := result.Definitions[defName]
if !found {
t.Fatalf("expecting definition to contain %s", defName)
}

// Verify that nested required fields are NOT hoisted to parent level
correctRequiredFields := []string{"thing"}
if got, want := definition.Required, correctRequiredFields; !reflect.DeepEqual(got, want) {
t.Errorf("Nested required fields were incorrectly hoisted to parent level.\n"+
"Body definition required fields:\n"+
" got = %v\n"+
" want = %v (only top-level field names)\n"+
"Nested field 'value' should be in 'thing' property's required array, not parent's.",
got, want)
}

var thingKV *keyVal
if definition.Properties != nil {
for i := range *definition.Properties {
if (*definition.Properties)[i].Key == "thing" {
thingKV = &(*definition.Properties)[i]
break
}
}
}

if thingKV == nil {
t.Fatal("'thing' property not found in body definition")
}

if _, ok := thingKV.Value.(openapiSchemaObject); !ok {
t.Fatal("'thing' property value is not an openapiSchemaObject")
}
}