diff --git a/protoc-gen-openapiv2/internal/genopenapi/template.go b/protoc-gen-openapiv2/internal/genopenapi/template.go index 6265da8fe48..d759e8f0521 100644 --- a/protoc-gen-openapiv2/internal/genopenapi/template.go +++ b/protoc-gen-openapiv2/internal/genopenapi/template.go @@ -607,11 +607,38 @@ func renderMessageAsDefinition(msg *descriptor.Message, reg *descriptor.Registry } if fieldSchema.Required != nil { - schema.Required = getUniqueFields(schema.Required, fieldSchema.Required) - schema.Required = append(schema.Required, fieldSchema.Required...) - // To avoid populating both the field schema require and message schema require, unset the field schema require. - // See issue #2635. - fieldSchema.Required = nil + // Only hoist required fields to parent if there are no path params inside this field. + if len(subPathParams) == 0 { + schema.Required = getUniqueFields(schema.Required, fieldSchema.Required) + schema.Required = append(schema.Required, fieldSchema.Required...) + // To avoid populating both the field schema require and message schema require, unset the field schema require. + // See issue #2635. + fieldSchema.Required = nil + } else { + // When there are path params, we need to separate field-level required from nested required. + // The field name itself (if required) should be in parent's required, but nested field names + // should stay in the nested schema's required. + fieldName := f.GetName() + if reg.GetUseJSONNamesForFields() { + fieldName = f.GetJsonName() + } + // Check if the field name is in the fieldSchema.Required (it would be if the field is marked REQUIRED) + var nestedRequired []string + fieldIsRequired := false + for _, req := range fieldSchema.Required { + if req == fieldName { + fieldIsRequired = true + } else { + nestedRequired = append(nestedRequired, req) + } + } + // Add the field name to parent's required if the field itself is required + if fieldIsRequired && find(schema.Required, fieldName) == -1 { + schema.Required = append(schema.Required, fieldName) + } + // Keep only the nested required fields in the field schema + fieldSchema.Required = nestedRequired + } } if reg.GetUseAllOfForRefs() { diff --git a/protoc-gen-openapiv2/internal/genopenapi/template_test.go b/protoc-gen-openapiv2/internal/genopenapi/template_test.go index a766a852430..2a296c1850a 100644 --- a/protoc-gen-openapiv2/internal/genopenapi/template_test.go +++ b/protoc-gen-openapiv2/internal/genopenapi/template_test.go @@ -11437,3 +11437,226 @@ func Test_updateSwaggerObjectFromFieldBehavior(t *testing.T) { }) } } + +// TestNestedRequiredFieldsNotHoisted tests the bug where nested required fields +// are incorrectly hoisted to the parent body schema's required array when path +// parameters reference nested fields +func TestNestedRequiredFieldsNotHoisted(t *testing.T) { + fieldBehaviorRequired := []annotations.FieldBehavior{annotations.FieldBehavior_REQUIRED} + requiredFieldOptions := new(descriptorpb.FieldOptions) + proto.SetExtension(requiredFieldOptions, annotations.E_FieldBehavior, fieldBehaviorRequired) + + // Define the nested message (Foo) with REQUIRED fields + fooDesc := &descriptorpb.DescriptorProto{ + Name: proto.String("Foo"), + Field: []*descriptorpb.FieldDescriptorProto{ + { + Name: proto.String("name"), + Type: descriptorpb.FieldDescriptorProto_TYPE_STRING.Enum(), + Number: proto.Int32(1), + Options: requiredFieldOptions, // name is REQUIRED + }, + { + Name: proto.String("value"), + Type: descriptorpb.FieldDescriptorProto_TYPE_STRING.Enum(), + Number: proto.Int32(2), + Options: requiredFieldOptions, // value is REQUIRED + }, + }, + } + + // Define the request message (UpdateFooRequest) + updateFooReqDesc := &descriptorpb.DescriptorProto{ + Name: proto.String("UpdateFooRequest"), + Field: []*descriptorpb.FieldDescriptorProto{ + { + Name: proto.String("thing"), + Type: descriptorpb.FieldDescriptorProto_TYPE_MESSAGE.Enum(), + TypeName: proto.String(".test.Foo"), + Number: proto.Int32(1), + Options: requiredFieldOptions, // thing is REQUIRED + }, + { + Name: proto.String("update_mask"), + Type: descriptorpb.FieldDescriptorProto_TYPE_STRING.Enum(), // Simplified - normally FieldMask + Number: proto.Int32(2), + }, + }, + } + + fooMsg := &descriptor.Message{ + DescriptorProto: fooDesc, + } + + updateFooReqMsg := &descriptor.Message{ + DescriptorProto: updateFooReqDesc, + } + + nameField := &descriptor.Field{ + Message: fooMsg, + FieldDescriptorProto: fooMsg.GetField()[0], + } + valueField := &descriptor.Field{ + Message: fooMsg, + FieldDescriptorProto: fooMsg.GetField()[1], + } + fooMsg.Fields = []*descriptor.Field{nameField, valueField} + + thingField := &descriptor.Field{ + Message: updateFooReqMsg, + FieldMessage: fooMsg, + FieldDescriptorProto: updateFooReqMsg.GetField()[0], + } + updateMaskField := &descriptor.Field{ + Message: updateFooReqMsg, + FieldDescriptorProto: updateFooReqMsg.GetField()[1], + } + updateFooReqMsg.Fields = []*descriptor.Field{thingField, updateMaskField} + + meth := &descriptorpb.MethodDescriptorProto{ + Name: proto.String("UpdateFoo"), + InputType: proto.String("UpdateFooRequest"), + OutputType: proto.String("Foo"), + } + + svc := &descriptorpb.ServiceDescriptorProto{ + Name: proto.String("FooService"), + Method: []*descriptorpb.MethodDescriptorProto{meth}, + } + + file := descriptor.File{ + FileDescriptorProto: &descriptorpb.FileDescriptorProto{ + SourceCodeInfo: &descriptorpb.SourceCodeInfo{}, + Name: proto.String("foo.proto"), + Package: proto.String("test"), + MessageType: []*descriptorpb.DescriptorProto{fooDesc, updateFooReqDesc}, + Service: []*descriptorpb.ServiceDescriptorProto{svc}, + Options: &descriptorpb.FileOptions{ + GoPackage: proto.String("github.com/example/test;test"), + }, + }, + GoPkg: descriptor.GoPackage{ + Path: "example.com/path/to/test/test.pb", + Name: "test_pb", + }, + Messages: []*descriptor.Message{fooMsg, updateFooReqMsg}, + Services: []*descriptor.Service{ + { + ServiceDescriptorProto: svc, + Methods: []*descriptor.Method{ + { + MethodDescriptorProto: meth, + RequestType: updateFooReqMsg, + ResponseType: fooMsg, + Bindings: []*descriptor.Binding{ + { + HTTPMethod: "PATCH", + PathTmpl: httprule.Template{ + Version: 1, + OpCodes: []int{0, 0}, + Template: "/api/v1/{thing.name}", + }, + PathParams: []descriptor.Parameter{ + { + FieldPath: descriptor.FieldPath([]descriptor.FieldPathComponent{ + { + Name: "thing", + }, + { + Name: "name", + }, + }), + Target: nameField, + }, + }, + Body: &descriptor.Body{ + FieldPath: []descriptor.FieldPathComponent{}, // body: "*" + }, + }, + }, + }, + }, + }, + }, + } + + reg := descriptor.NewRegistry() + fileCL := crossLinkFixture(&file) + err := reg.Load(reqFromFile(fileCL)) + if err != nil { + t.Errorf("reg.Load(%#v) failed with %v; want success", file, err) + return + } + + result, err := applyTemplate(param{File: fileCL, reg: reg}) + if err != nil { + t.Fatalf("applyTemplate(%#v) failed with %v; want success", file, err) + } + + paths := GetPaths(result) + if got, want := len(paths), 1; got != want { + t.Fatalf("Results path length differed, got %d want %d", got, want) + } + if got, want := paths[0], "/api/v1/{thing.name}"; got != want { + t.Fatalf("Wrong results path, got %s want %s", got, want) + } + + operation := *result.getPathItemObject("/api/v1/{thing.name}").Patch + if len(operation.Parameters) < 2 { + t.Fatalf("Expected at least 2 parameters, got %d", len(operation.Parameters)) + } + + if got, want := operation.Parameters[0].Name, "thing.name"; got != want { + t.Fatalf("Wrong parameter name 0, got %s want %s", got, want) + } + if got, want := operation.Parameters[0].In, "path"; got != want { + t.Fatalf("Wrong parameter location 0, got %s want %s", got, want) + } + + if got, want := operation.Parameters[1].Name, "body"; got != want { + t.Fatalf("Wrong parameter name 1, got %s want %s", got, want) + } + if got, want := operation.Parameters[1].In, "body"; got != want { + t.Fatalf("Wrong parameter location 1, got %s want %s", got, want) + } + + bodySchemaRef := operation.Parameters[1].Schema.schemaCore.Ref + if bodySchemaRef == "" { + t.Fatal("Body schema reference is empty") + } + + defName := strings.TrimPrefix(bodySchemaRef, "#/definitions/") + definition, found := result.Definitions[defName] + if !found { + t.Fatalf("expecting definition to contain %s", defName) + } + + // Verify that nested required fields are NOT hoisted to parent level + correctRequiredFields := []string{"thing"} + if got, want := definition.Required, correctRequiredFields; !reflect.DeepEqual(got, want) { + t.Errorf("Nested required fields were incorrectly hoisted to parent level.\n"+ + "Body definition required fields:\n"+ + " got = %v\n"+ + " want = %v (only top-level field names)\n"+ + "Nested field 'value' should be in 'thing' property's required array, not parent's.", + got, want) + } + + var thingKV *keyVal + if definition.Properties != nil { + for i := range *definition.Properties { + if (*definition.Properties)[i].Key == "thing" { + thingKV = &(*definition.Properties)[i] + break + } + } + } + + if thingKV == nil { + t.Fatal("'thing' property not found in body definition") + } + + if _, ok := thingKV.Value.(openapiSchemaObject); !ok { + t.Fatal("'thing' property value is not an openapiSchemaObject") + } +}