From 29845743f72233f692034ad580caef61bb253ce3 Mon Sep 17 00:00:00 2001 From: Robert Yokota Date: Thu, 3 Jul 2025 10:00:12 -0700 Subject: [PATCH] Refactor encryptor --- .../rules/encryption/encrypt_executor.go | 75 +++++++------ .../rules/encryption/encrypt_executor_test.go | 2 +- .../encryption/field_encrypt_executor.go | 86 +++++++++++++++ schemaregistry/schemaregistry_client.go | 32 +++++- schemaregistry/serde/avrov2/avro.go | 15 ++- schemaregistry/serde/avrov2/avro_test.go | 103 +++++++++++++++--- .../serde/jsonschema/json_schema.go | 15 ++- .../serde/jsonschema/json_schema_test.go | 69 ++++++++++++ schemaregistry/serde/protobuf/protobuf.go | 15 ++- .../serde/protobuf/protobuf_test.go | 75 +++++++++++++ schemaregistry/serde/serde.go | 18 ++- 11 files changed, 439 insertions(+), 66 deletions(-) create mode 100644 schemaregistry/rules/encryption/field_encrypt_executor.go diff --git a/schemaregistry/rules/encryption/encrypt_executor.go b/schemaregistry/rules/encryption/encrypt_executor.go index b5449c900..4c34ee920 100644 --- a/schemaregistry/rules/encryption/encrypt_executor.go +++ b/schemaregistry/rules/encryption/encrypt_executor.go @@ -44,10 +44,11 @@ func init() { // Register registers the encryption rule executor func Register() { serde.RegisterRuleExecutor(NewExecutor()) + serde.RegisterRuleExecutor(NewFieldExecutor()) } -// RegisterWithClock registers the encryption rule executor with a given clock -func RegisterWithClock(c Clock) *FieldEncryptionExecutor { +// RegisterExecutorWithClock registers the encryption rule executor with a given clock +func RegisterExecutorWithClock(c Clock) *Executor { f := NewExecutorWithClock(c) serde.RegisterRuleExecutor(f) return f @@ -60,10 +61,8 @@ func NewExecutor() serde.RuleExecutor { } // NewExecutorWithClock creates a new encryption rule executor with a given clock -func NewExecutorWithClock(c Clock) *FieldEncryptionExecutor { - a := &serde.AbstractFieldRuleExecutor{} - f := &FieldEncryptionExecutor{*a, nil, nil, c} - f.FieldRuleExecutor = f +func NewExecutorWithClock(c Clock) *Executor { + f := &Executor{nil, nil, c} return f } @@ -101,16 +100,15 @@ func (*clock) NowUnixMilli() int64 { return time.Now().UnixMilli() } -// FieldEncryptionExecutor is a field encryption executor -type FieldEncryptionExecutor struct { - serde.AbstractFieldRuleExecutor +// Executor is an encryption executor +type Executor struct { Config map[string]string Client deks.Client Clock Clock } // Configure configures the executor -func (f *FieldEncryptionExecutor) Configure(clientConfig *schemaregistry.Config, config map[string]string) error { +func (f *Executor) Configure(clientConfig *schemaregistry.Config, config map[string]string) error { if f.Client != nil { if !schemaregistry.ConfigsEqual(f.Client.Config(), clientConfig) { return errors.New("executor already configured") @@ -143,12 +141,21 @@ func (f *FieldEncryptionExecutor) Configure(clientConfig *schemaregistry.Config, } // Type returns the type of the executor -func (f *FieldEncryptionExecutor) Type() string { - return "ENCRYPT" +func (f *Executor) Type() string { + return "ENCRYPT_PAYLOAD" +} + +// Transform transforms the message using the rule +func (f *Executor) Transform(ctx serde.RuleContext, msg interface{}) (interface{}, error) { + transform, err := f.NewTransform(ctx) + if err != nil { + return nil, err + } + return transform.Transform(ctx, serde.TypeBytes, msg) } // NewTransform creates a new transform -func (f *FieldEncryptionExecutor) NewTransform(ctx serde.RuleContext) (serde.FieldTransform, error) { +func (f *Executor) NewTransform(ctx serde.RuleContext) (*ExecutorTransform, error) { kekName, err := getKekName(ctx) if err != nil { return nil, err @@ -157,7 +164,7 @@ func (f *FieldEncryptionExecutor) NewTransform(ctx serde.RuleContext) (serde.Fie if err != nil { return nil, err } - transform := FieldEncryptionExecutorTransform{ + transform := ExecutorTransform{ Executor: *f, Cryptor: getCryptor(ctx), KekName: kekName, @@ -172,13 +179,13 @@ func (f *FieldEncryptionExecutor) NewTransform(ctx serde.RuleContext) (serde.Fie } // Close closes the executor -func (f *FieldEncryptionExecutor) Close() error { +func (f *Executor) Close() error { return f.Client.Close() } -// FieldEncryptionExecutorTransform is a field encryption executor transform -type FieldEncryptionExecutorTransform struct { - Executor FieldEncryptionExecutor +// ExecutorTransform is a field encryption executor transform +type ExecutorTransform struct { + Executor Executor Cryptor Cryptor KekName string Kek deks.Kek @@ -290,11 +297,11 @@ func getDekExpiryDays(ctx serde.RuleContext) (int, error) { return i, nil } -func (f *FieldEncryptionExecutorTransform) isDekRotated() bool { +func (f *ExecutorTransform) isDekRotated() bool { return f.DekExpiryDays > 0 } -func (f *FieldEncryptionExecutorTransform) getOrCreateKek(ctx serde.RuleContext) (*deks.Kek, error) { +func (f *ExecutorTransform) getOrCreateKek(ctx serde.RuleContext) (*deks.Kek, error) { isRead := ctx.RuleMode == schemaregistry.Read kekID := deks.KekID{ Name: f.KekName, @@ -334,7 +341,7 @@ func (f *FieldEncryptionExecutorTransform) getOrCreateKek(ctx serde.RuleContext) return kek, nil } -func (f *FieldEncryptionExecutorTransform) retrieveKekFromRegistry(key deks.KekID) (*deks.Kek, error) { +func (f *ExecutorTransform) retrieveKekFromRegistry(key deks.KekID) (*deks.Kek, error) { kek, err := f.Executor.Client.GetKek(key.Name, key.Deleted) if err != nil { var restErr *rest.Error @@ -348,7 +355,7 @@ func (f *FieldEncryptionExecutorTransform) retrieveKekFromRegistry(key deks.KekI return &kek, nil } -func (f *FieldEncryptionExecutorTransform) storeKekToRegistry(key deks.KekID, kmsType string, kmsKeyID string, shared bool) (*deks.Kek, error) { +func (f *ExecutorTransform) storeKekToRegistry(key deks.KekID, kmsType string, kmsKeyID string, shared bool) (*deks.Kek, error) { kek, err := f.Executor.Client.RegisterKek(key.Name, kmsType, kmsKeyID, nil, "", shared) if err != nil { var restErr *rest.Error @@ -362,7 +369,7 @@ func (f *FieldEncryptionExecutorTransform) storeKekToRegistry(key deks.KekID, km return &kek, nil } -func (f *FieldEncryptionExecutorTransform) getOrCreateDek(ctx serde.RuleContext, version *int) (*deks.Dek, error) { +func (f *ExecutorTransform) getOrCreateDek(ctx serde.RuleContext, version *int) (*deks.Dek, error) { isRead := ctx.RuleMode == schemaregistry.Read ver := 1 if version != nil { @@ -442,7 +449,7 @@ func (f *FieldEncryptionExecutorTransform) getOrCreateDek(ctx serde.RuleContext, return dek, nil } -func (f *FieldEncryptionExecutorTransform) createDek(dekID deks.DekID, newVersion int, encryptedDek []byte) (*deks.Dek, error) { +func (f *ExecutorTransform) createDek(dekID deks.DekID, newVersion int, encryptedDek []byte) (*deks.Dek, error) { newDekID := deks.DekID{ KekName: dekID.KekName, Subject: dekID.Subject, @@ -466,7 +473,7 @@ func (f *FieldEncryptionExecutorTransform) createDek(dekID deks.DekID, newVersio return dek, nil } -func (f *FieldEncryptionExecutorTransform) retrieveDekFromRegistry(key deks.DekID) (*deks.Dek, error) { +func (f *ExecutorTransform) retrieveDekFromRegistry(key deks.DekID) (*deks.Dek, error) { var dek deks.Dek var err error if key.Version != 0 { @@ -486,7 +493,7 @@ func (f *FieldEncryptionExecutorTransform) retrieveDekFromRegistry(key deks.DekI return &dek, nil } -func (f *FieldEncryptionExecutorTransform) storeDekToRegistry(key deks.DekID, encryptedDek []byte) (*deks.Dek, error) { +func (f *ExecutorTransform) storeDekToRegistry(key deks.DekID, encryptedDek []byte) (*deks.Dek, error) { var encryptedDekStr string if encryptedDek != nil { encryptedDekStr = base64.StdEncoding.EncodeToString(encryptedDek) @@ -510,7 +517,7 @@ func (f *FieldEncryptionExecutorTransform) storeDekToRegistry(key deks.DekID, en return &dek, nil } -func (f *FieldEncryptionExecutorTransform) isExpired(ctx serde.RuleContext, dek *deks.Dek) bool { +func (f *ExecutorTransform) isExpired(ctx serde.RuleContext, dek *deks.Dek) bool { now := f.Executor.Clock.NowUnixMilli() return ctx.RuleMode != schemaregistry.Read && f.DekExpiryDays > 0 && @@ -519,15 +526,15 @@ func (f *FieldEncryptionExecutorTransform) isExpired(ctx serde.RuleContext, dek } // Transform transforms the field value using the rule -func (f *FieldEncryptionExecutorTransform) Transform(ctx serde.RuleContext, fieldCtx serde.FieldContext, fieldValue interface{}) (interface{}, error) { +func (f *ExecutorTransform) Transform(ctx serde.RuleContext, fieldType serde.FieldType, fieldValue interface{}) (interface{}, error) { if fieldValue == nil { return nil, nil } switch ctx.RuleMode { case schemaregistry.Write: - plaintext := toBytes(fieldCtx.Type, fieldValue) + plaintext := toBytes(fieldType, fieldValue) if plaintext == nil { - return nil, fmt.Errorf("type '%v' not supported for encryption", fieldCtx.Type) + return nil, fmt.Errorf("type '%v' not supported for encryption", fieldType) } var version *int if f.isDekRotated() { @@ -552,16 +559,16 @@ func (f *FieldEncryptionExecutorTransform) Transform(ctx serde.RuleContext, fiel return nil, err } } - if fieldCtx.Type == serde.TypeString { + if fieldType == serde.TypeString { return base64.StdEncoding.EncodeToString(ciphertext), nil } return ciphertext, nil case schemaregistry.Read: - ciphertext := toBytes(fieldCtx.Type, fieldValue) + ciphertext := toBytes(fieldType, fieldValue) if ciphertext == nil { return fieldValue, nil } - if fieldCtx.Type == serde.TypeString { + if fieldType == serde.TypeString { var err error ciphertext, err = base64.StdEncoding.DecodeString(string(ciphertext)) if err != nil { @@ -589,7 +596,7 @@ func (f *FieldEncryptionExecutorTransform) Transform(ctx serde.RuleContext, fiel if err != nil { return nil, err } - return toObject(fieldCtx.Type, plaintext), nil + return toObject(fieldType, plaintext), nil default: return nil, fmt.Errorf("unsupported rule mode %v", ctx.RuleMode) } diff --git a/schemaregistry/rules/encryption/encrypt_executor_test.go b/schemaregistry/rules/encryption/encrypt_executor_test.go index 8c89fbb6f..9eab95056 100644 --- a/schemaregistry/rules/encryption/encrypt_executor_test.go +++ b/schemaregistry/rules/encryption/encrypt_executor_test.go @@ -25,7 +25,7 @@ import ( "github.com/confluentinc/confluent-kafka-go/v2/schemaregistry" ) -func TestFieldEncryptionExecutor_Configure(t *testing.T) { +func TestEncryptionExecutor_Configure(t *testing.T) { maybeFail = initFailFunc(t) executor := NewExecutor() diff --git a/schemaregistry/rules/encryption/field_encrypt_executor.go b/schemaregistry/rules/encryption/field_encrypt_executor.go new file mode 100644 index 000000000..6ea0cfebd --- /dev/null +++ b/schemaregistry/rules/encryption/field_encrypt_executor.go @@ -0,0 +1,86 @@ +/** + * Copyright 2024 Confluent Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package encryption + +import ( + "github.com/confluentinc/confluent-kafka-go/v2/schemaregistry" + "github.com/confluentinc/confluent-kafka-go/v2/schemaregistry/serde" +) + +// RegisterFieldExecutorWithClock registers the encryption rule executor with a given clock +func RegisterFieldExecutorWithClock(c Clock) *FieldEncryptionExecutor { + f := NewFieldExecutorWithClock(c) + serde.RegisterRuleExecutor(f) + return f +} + +// NewFieldExecutor creates a new encryption rule executor +func NewFieldExecutor() serde.RuleExecutor { + c := clock{} + return NewFieldExecutorWithClock(&c) +} + +// NewFieldExecutorWithClock creates a new encryption rule executor with a given clock +func NewFieldExecutorWithClock(c Clock) *FieldEncryptionExecutor { + a := &serde.AbstractFieldRuleExecutor{} + f := &FieldEncryptionExecutor{*a, *NewExecutorWithClock(c)} + f.FieldRuleExecutor = f + return f +} + +// FieldEncryptionExecutor is a field encryption executor +type FieldEncryptionExecutor struct { + serde.AbstractFieldRuleExecutor + Executor Executor +} + +// Configure configures the executor +func (f *FieldEncryptionExecutor) Configure(clientConfig *schemaregistry.Config, config map[string]string) error { + return f.Executor.Configure(clientConfig, config) +} + +// Type returns the type of the executor +func (f *FieldEncryptionExecutor) Type() string { + return "ENCRYPT" +} + +// NewTransform creates a new transform +func (f *FieldEncryptionExecutor) NewTransform(ctx serde.RuleContext) (serde.FieldTransform, error) { + executorTransform, err := f.Executor.NewTransform(ctx) + if err != nil { + return nil, err + } + transform := FieldEncryptionExecutorTransform{ + ExecutorTransform: *executorTransform, + } + return &transform, nil +} + +// Close closes the executor +func (f *FieldEncryptionExecutor) Close() error { + return f.Executor.Close() +} + +// FieldEncryptionExecutorTransform is a field encryption executor transform +type FieldEncryptionExecutorTransform struct { + ExecutorTransform ExecutorTransform +} + +// Transform transforms the field value using the rule +func (f *FieldEncryptionExecutorTransform) Transform(ctx serde.RuleContext, fieldCtx serde.FieldContext, fieldValue interface{}) (interface{}, error) { + return f.ExecutorTransform.Transform(ctx, fieldCtx.Type, fieldValue) +} diff --git a/schemaregistry/schemaregistry_client.go b/schemaregistry/schemaregistry_client.go index dfae73f19..d3d9c83d5 100644 --- a/schemaregistry/schemaregistry_client.go +++ b/schemaregistry/schemaregistry_client.go @@ -101,6 +101,18 @@ type Rule struct { Disabled bool `json:"disabled,omitempty"` } +// RulePhase represents the rule phase +type RulePhase = int + +const ( + // MigrationPhase denotes migration phase + MigrationPhase = 1 + // DomainPhase denotes domain phase + DomainPhase = 2 + // EncodingPhase denotes encoding phase + EncodingPhase = 3 +) + // RuleMode represents the rule mode type RuleMode = int @@ -138,25 +150,35 @@ func ParseMode(mode string) (RuleMode, bool) { type RuleSet struct { MigrationRules []Rule `json:"migrationRules,omitempty"` DomainRules []Rule `json:"domainRules,omitempty"` + EncodingRules []Rule `json:"encodingRules,omitempty"` } // HasRules checks if the ruleset has rules for the given mode -func (r *RuleSet) HasRules(mode RuleMode) bool { +func (r *RuleSet) HasRules(phase RulePhase, mode RuleMode) bool { + var rules []Rule + switch phase { + case MigrationPhase: + rules = r.MigrationRules + case DomainPhase: + rules = r.DomainRules + case EncodingPhase: + rules = r.EncodingRules + } switch mode { case Upgrade, Downgrade: - return r.hasRules(r.MigrationRules, func(ruleMode RuleMode) bool { + return r.hasRules(rules, func(ruleMode RuleMode) bool { return ruleMode == mode || ruleMode == UpDown }) case UpDown: - return r.hasRules(r.MigrationRules, func(ruleMode RuleMode) bool { + return r.hasRules(rules, func(ruleMode RuleMode) bool { return ruleMode == mode }) case Write, Read: - return r.hasRules(r.DomainRules, func(ruleMode RuleMode) bool { + return r.hasRules(rules, func(ruleMode RuleMode) bool { return ruleMode == mode || ruleMode == WriteRead }) case WriteRead: - return r.hasRules(r.DomainRules, func(ruleMode RuleMode) bool { + return r.hasRules(rules, func(ruleMode RuleMode) bool { return ruleMode == mode }) } diff --git a/schemaregistry/serde/avrov2/avro.go b/schemaregistry/serde/avrov2/avro.go index 84072e9e9..b1a8f8cbc 100644 --- a/schemaregistry/serde/avrov2/avro.go +++ b/schemaregistry/serde/avrov2/avro.go @@ -143,7 +143,12 @@ func (s *Serializer) SerializeWithHeaders(topic string, msg interface{}) ([]kafk if err != nil { return nil, nil, err } - return s.SchemaIDSerializer(topic, s.SerdeType, msgBytes, schemaID) + msg, err = s.ExecuteRulesWithPhase(subject, topic, + schemaregistry.EncodingPhase, schemaregistry.Write, nil, &info, msgBytes) + if err != nil { + return nil, nil, err + } + return s.SchemaIDSerializer(topic, s.SerdeType, msg.([]byte), schemaID) } // NewDeserializer creates an Avro deserializer for generic objects @@ -210,6 +215,13 @@ func (s *Deserializer) deserialize(topic string, headers []kafka.Header, payload if err != nil { return nil, err } + var msg interface{} + msg, err = s.ExecuteRulesWithPhase(subject, topic, + schemaregistry.EncodingPhase, schemaregistry.Read, nil, &info, payload) + if err != nil { + return nil, err + } + payload = msg.([]byte) readerMeta, err := s.GetReaderSchema(subject) if err != nil { return nil, err @@ -225,7 +237,6 @@ func (s *Deserializer) deserialize(topic string, headers []kafka.Header, payload if err != nil { return nil, err } - var msg interface{} if len(migrations) > 0 { err = s.api.Unmarshal(writer, payload, &msg) if err != nil { diff --git a/schemaregistry/serde/avrov2/avro_test.go b/schemaregistry/serde/avrov2/avro_test.go index 0b8e377c4..fae815a17 100644 --- a/schemaregistry/serde/avrov2/avro_test.go +++ b/schemaregistry/serde/avrov2/avro_test.go @@ -1585,6 +1585,75 @@ func TestAvroSerdeEncryption(t *testing.T) { serde.MaybeFail("deserialization", err, serde.Expect(newobj, &obj)) } +func TestAvroSerdePayloadEncryption(t *testing.T) { + serde.MaybeFail = serde.InitFailFunc(t) + var err error + + conf := schemaregistry.NewConfig("mock://") + + client, err := schemaregistry.NewClient(conf) + serde.MaybeFail("Schema Registry configuration", err) + + serConfig := NewSerializerConfig() + serConfig.AutoRegisterSchemas = false + serConfig.UseLatestVersion = true + serConfig.RuleConfig = map[string]string{ + "secret": "mysecret", + } + ser, err := NewSerializer(client, serde.ValueSerde, serConfig) + serde.MaybeFail("Serializer configuration", err) + + encRule := schemaregistry.Rule{ + Name: "test-encrypt", + Kind: "TRANSFORM", + Mode: "WRITEREAD", + Type: "ENCRYPT_PAYLOAD", + Params: map[string]string{ + "encrypt.kek.name": "kek1", + "encrypt.kms.type": "local-kms", + "encrypt.kms.key.id": "mykey", + }, + OnFailure: "ERROR,NONE", + } + ruleSet := schemaregistry.RuleSet{ + EncodingRules: []schemaregistry.Rule{encRule}, + } + + info := schemaregistry.SchemaInfo{ + Schema: demoSchema, + SchemaType: "AVRO", + RuleSet: &ruleSet, + } + + id, err := client.Register("topic1-value", info, false) + serde.MaybeFail("Schema registration", err) + if id <= 0 { + t.Errorf("Expected valid schema id, found %d", id) + } + + obj := DemoSchema{} + obj.IntField = 123 + obj.DoubleField = 45.67 + obj.StringField = "hi" + obj.BoolField = true + obj.BytesField = []byte{1, 2} + + bytes, err := ser.Serialize("topic1", &obj) + serde.MaybeFail("serialization", err) + + deserConfig := NewDeserializerConfig() + deserConfig.RuleConfig = map[string]string{ + "secret": "mysecret", + } + deser, err := NewDeserializer(client, serde.ValueSerde, deserConfig) + serde.MaybeFail("Deserializer configuration", err) + deser.Client = ser.Client + deser.MessageFactory = testMessageFactory + + newobj, err := deser.Deserialize("topic1", bytes) + serde.MaybeFail("deserialization", err, serde.Expect(newobj, &obj)) +} + func TestAvroSerdeEncryptionDeterministic(t *testing.T) { serde.MaybeFail = serde.InitFailFunc(t) var err error @@ -1737,7 +1806,7 @@ func TestAvroSerdeEncryptionWithSimpleMap(t *testing.T) { func TestAvroSerdeEncryptionDekRotation(t *testing.T) { f := fakeClock{now: time.Now().UnixMilli()} - executor := encryption.RegisterWithClock(&f) + executor := encryption.RegisterFieldExecutorWithClock(&f) serde.MaybeFail = serde.InitFailFunc(t) var err error @@ -1811,7 +1880,7 @@ func TestAvroSerdeEncryptionDekRotation(t *testing.T) { newobj, err := deser.Deserialize("topic1", bytes) serde.MaybeFail("deserialization", err, serde.Expect(newobj, &obj)) - dek, err := executor.Client.GetDekVersion( + dek, err := executor.Executor.Client.GetDekVersion( "kek1", "topic1-value", -1, "AES256_GCM", false) serde.MaybeFail("DEK retrieval", err, serde.Expect(dek.Version, 1)) @@ -1834,7 +1903,7 @@ func TestAvroSerdeEncryptionDekRotation(t *testing.T) { newobj, err = deser.Deserialize("topic1", bytes) serde.MaybeFail("deserialization", err, serde.Expect(newobj, &obj)) - dek, err = executor.Client.GetDekVersion( + dek, err = executor.Executor.Client.GetDekVersion( "kek1", "topic1-value", -1, "AES256_GCM", false) serde.MaybeFail("DEK retrieval", err, serde.Expect(dek.Version, 2)) @@ -1857,16 +1926,16 @@ func TestAvroSerdeEncryptionDekRotation(t *testing.T) { newobj, err = deser.Deserialize("topic1", bytes) serde.MaybeFail("deserialization", err, serde.Expect(newobj, &obj)) - dek, err = executor.Client.GetDekVersion( + dek, err = executor.Executor.Client.GetDekVersion( "kek1", "topic1-value", -1, "AES256_GCM", false) serde.MaybeFail("DEK retrieval", err, serde.Expect(dek.Version, 3)) - executor.Client.Close() + executor.Executor.Client.Close() } func TestAvroSerdeEncryptionF1Preserialized(t *testing.T) { c := clock{} - executor := encryption.RegisterWithClock(&c) + executor := encryption.RegisterFieldExecutorWithClock(&c) serde.MaybeFail = serde.InitFailFunc(t) var err error @@ -1916,23 +1985,23 @@ func TestAvroSerdeEncryptionF1Preserialized(t *testing.T) { serde.MaybeFail("Deserializer configuration", err) deser.MessageFactory = testMessageFactory - executor.Client.RegisterKek("kek1", "local-kms", "mykey", make(map[string]string), "", false) + executor.Executor.Client.RegisterKek("kek1", "local-kms", "mykey", make(map[string]string), "", false) serde.MaybeFail("Kek registration", err) encryptedDek := "07V2ndh02DA73p+dTybwZFm7DKQSZN1tEwQh+FoX1DZLk4Yj2LLu4omYjp/84tAg3BYlkfGSz+zZacJHIE4=" - executor.Client.RegisterDek("kek1", "topic1-value", "AES256_GCM", encryptedDek) + executor.Executor.Client.RegisterDek("kek1", "topic1-value", "AES256_GCM", encryptedDek) serde.MaybeFail("Dek registration", err) bytes := []byte{0, 0, 0, 0, 1, 104, 122, 103, 121, 47, 106, 70, 78, 77, 86, 47, 101, 70, 105, 108, 97, 72, 114, 77, 121, 101, 66, 103, 100, 97, 86, 122, 114, 82, 48, 117, 100, 71, 101, 111, 116, 87, 56, 99, 65, 47, 74, 97, 108, 55, 117, 107, 114, 43, 77, 47, 121, 122} newobj, err := deser.Deserialize("topic1", bytes) serde.MaybeFail("deserialization", err, serde.Expect(newobj, &obj)) - executor.Client.Close() + executor.Executor.Client.Close() } func TestAvroSerdeEncryptionDeterministicF1Preserialized(t *testing.T) { c := clock{} - executor := encryption.RegisterWithClock(&c) + executor := encryption.RegisterFieldExecutorWithClock(&c) serde.MaybeFail = serde.InitFailFunc(t) var err error @@ -1983,23 +2052,23 @@ func TestAvroSerdeEncryptionDeterministicF1Preserialized(t *testing.T) { serde.MaybeFail("Deserializer configuration", err) deser.MessageFactory = testMessageFactory - executor.Client.RegisterKek("kek1", "local-kms", "mykey", make(map[string]string), "", false) + executor.Executor.Client.RegisterKek("kek1", "local-kms", "mykey", make(map[string]string), "", false) serde.MaybeFail("Kek registration", err) encryptedDek := "YSx3DTlAHrmpoDChquJMifmPntBzxgRVdMzgYL82rgWBKn7aUSnG+WIu9ozBNS3y2vXd++mBtK07w4/W/G6w0da39X9hfOVZsGnkSvry/QRht84V8yz3dqKxGMOK5A==" - executor.Client.RegisterDek("kek1", "topic1-value", "AES256_SIV", encryptedDek) + executor.Executor.Client.RegisterDek("kek1", "topic1-value", "AES256_SIV", encryptedDek) serde.MaybeFail("Dek registration", err) bytes := []byte{0, 0, 0, 0, 1, 72, 68, 54, 89, 116, 120, 114, 108, 66, 110, 107, 84, 87, 87, 57, 78, 54, 86, 98, 107, 51, 73, 73, 110, 106, 87, 72, 56, 49, 120, 109, 89, 104, 51, 107, 52, 100} newobj, err := deser.Deserialize("topic1", bytes) serde.MaybeFail("deserialization", err, serde.Expect(newobj, &obj)) - executor.Client.Close() + executor.Executor.Client.Close() } func TestAvroSerdeEncryptionDekRotationF1Preserialized(t *testing.T) { c := clock{} - executor := encryption.RegisterWithClock(&c) + executor := encryption.RegisterFieldExecutorWithClock(&c) serde.MaybeFail = serde.InitFailFunc(t) var err error @@ -2050,18 +2119,18 @@ func TestAvroSerdeEncryptionDekRotationF1Preserialized(t *testing.T) { serde.MaybeFail("Deserializer configuration", err) deser.MessageFactory = testMessageFactory - executor.Client.RegisterKek("kek1", "local-kms", "mykey", make(map[string]string), "", false) + executor.Executor.Client.RegisterKek("kek1", "local-kms", "mykey", make(map[string]string), "", false) serde.MaybeFail("Kek registration", err) encryptedDek := "W/v6hOQYq1idVAcs1pPWz9UUONMVZW4IrglTnG88TsWjeCjxmtRQ4VaNe/I5dCfm2zyY9Cu0nqdvqImtUk4=" - executor.Client.RegisterDek("kek1", "topic1-value", "AES256_GCM", encryptedDek) + executor.Executor.Client.RegisterDek("kek1", "topic1-value", "AES256_GCM", encryptedDek) serde.MaybeFail("Dek registration", err) bytes := []byte{0, 0, 0, 0, 1, 120, 65, 65, 65, 65, 65, 65, 71, 52, 72, 73, 54, 98, 49, 110, 88, 80, 88, 113, 76, 121, 71, 56, 99, 73, 73, 51, 53, 78, 72, 81, 115, 101, 113, 113, 85, 67, 100, 43, 73, 101, 76, 101, 70, 86, 65, 101, 78, 112, 83, 83, 51, 102, 120, 80, 110, 74, 51, 50, 65, 61} newobj, err := deser.Deserialize("topic1", bytes) serde.MaybeFail("deserialization", err, serde.Expect(newobj, &obj)) - executor.Client.Close() + executor.Executor.Client.Close() } func TestAvroSerdeEncryptionWithReferences(t *testing.T) { diff --git a/schemaregistry/serde/jsonschema/json_schema.go b/schemaregistry/serde/jsonschema/json_schema.go index c7c7c92a3..8ae8bc750 100644 --- a/schemaregistry/serde/jsonschema/json_schema.go +++ b/schemaregistry/serde/jsonschema/json_schema.go @@ -148,7 +148,12 @@ func (s *Serializer) SerializeWithHeaders(topic string, msg interface{}) ([]kafk return nil, nil, err } } - return s.SchemaIDSerializer(topic, s.SerdeType, raw, schemaID) + msg, err = s.ExecuteRulesWithPhase(subject, topic, + schemaregistry.EncodingPhase, schemaregistry.Write, nil, &info, raw) + if err != nil { + return nil, nil, err + } + return s.SchemaIDSerializer(topic, s.SerdeType, msg.([]byte), schemaID) } // NewDeserializer creates a JSON deserializer for generic objects @@ -214,6 +219,13 @@ func (s *Deserializer) deserialize(topic string, headers []kafka.Header, payload if err != nil { return nil, err } + var msg interface{} + msg, err = s.ExecuteRulesWithPhase(subject, topic, + schemaregistry.EncodingPhase, schemaregistry.Read, nil, &info, payload) + if err != nil { + return nil, err + } + payload = msg.([]byte) readerMeta, err := s.GetReaderSchema(subject) if err != nil { return nil, err @@ -225,7 +237,6 @@ func (s *Deserializer) deserialize(topic string, headers []kafka.Header, payload return nil, err } } - var msg interface{} bytes := payload if len(migrations) > 0 { err = json.Unmarshal(bytes, &msg) diff --git a/schemaregistry/serde/jsonschema/json_schema_test.go b/schemaregistry/serde/jsonschema/json_schema_test.go index 5e58b0643..1d0b674aa 100644 --- a/schemaregistry/serde/jsonschema/json_schema_test.go +++ b/schemaregistry/serde/jsonschema/json_schema_test.go @@ -1246,6 +1246,75 @@ func TestJSONSchemaSerdeEncryption(t *testing.T) { serde.MaybeFail("deserialization", err, serde.Expect(&newobj, &obj)) } +func TestJSONSchemaSerdePayloadEncryption(t *testing.T) { + serde.MaybeFail = serde.InitFailFunc(t) + var err error + + conf := schemaregistry.NewConfig("mock://") + + client, err := schemaregistry.NewClient(conf) + serde.MaybeFail("Schema Registry configuration", err) + + serConfig := NewSerializerConfig() + serConfig.AutoRegisterSchemas = false + serConfig.UseLatestVersion = true + serConfig.RuleConfig = map[string]string{ + "secret": "mysecret", + } + ser, err := NewSerializer(client, serde.ValueSerde, serConfig) + serde.MaybeFail("Serializer configuration", err) + + encRule := schemaregistry.Rule{ + Name: "test-encrypt", + Kind: "TRANSFORM", + Mode: "WRITEREAD", + Type: "ENCRYPT_PAYLOAD", + Params: map[string]string{ + "encrypt.kek.name": "kek1", + "encrypt.kms.type": "local-kms", + "encrypt.kms.key.id": "mykey", + }, + OnFailure: "ERROR,NONE", + } + ruleSet := schemaregistry.RuleSet{ + EncodingRules: []schemaregistry.Rule{encRule}, + } + + info := schemaregistry.SchemaInfo{ + Schema: demoSchema, + SchemaType: "JSON", + RuleSet: &ruleSet, + } + + id, err := client.Register("topic1-value", info, false) + serde.MaybeFail("Schema registration", err) + if id <= 0 { + t.Errorf("Expected valid schema id, found %d", id) + } + + obj := JSONDemoSchema{} + obj.IntField = 123 + obj.DoubleField = 45.67 + obj.StringField = "hi" + obj.BoolField = true + obj.BytesField = base64.StdEncoding.EncodeToString([]byte{1, 2}) + + bytes, err := ser.Serialize("topic1", &obj) + serde.MaybeFail("serialization", err) + + deserConfig := NewDeserializerConfig() + deserConfig.RuleConfig = map[string]string{ + "secret": "mysecret", + } + deser, err := NewDeserializer(client, serde.ValueSerde, deserConfig) + serde.MaybeFail("Deserializer configuration", err) + deser.Client = ser.Client + + var newobj JSONDemoSchema + err = deser.DeserializeInto("topic1", bytes, &newobj) + serde.MaybeFail("deserialization", err, serde.Expect(&newobj, &obj)) +} + func TestJSONSchemaSerdeEncryptionWithUnion(t *testing.T) { serde.MaybeFail = serde.InitFailFunc(t) var err error diff --git a/schemaregistry/serde/protobuf/protobuf.go b/schemaregistry/serde/protobuf/protobuf.go index f6808fbbc..7c700b9b5 100644 --- a/schemaregistry/serde/protobuf/protobuf.go +++ b/schemaregistry/serde/protobuf/protobuf.go @@ -248,7 +248,12 @@ func (s *Serializer) SerializeWithHeaders(topic string, msg interface{}) ([]kafk if err != nil { return nil, nil, err } - return s.SchemaIDSerializer(topic, s.SerdeType, msgBytes, schemaID) + msg, err = s.ExecuteRulesWithPhase(subject, topic, + schemaregistry.EncodingPhase, schemaregistry.Write, nil, &info, msgBytes) + if err != nil { + return nil, nil, err + } + return s.SchemaIDSerializer(topic, s.SerdeType, msg.([]byte), schemaID) } func (s *Serializer) getSchemaInfo(protoMsg proto.Message) (*schemaregistry.SchemaInfo, error) { @@ -563,6 +568,13 @@ func (s *Deserializer) deserialize(topic string, headers []kafka.Header, payload if err != nil { return nil, err } + var msg interface{} + msg, err = s.ExecuteRulesWithPhase(subject, topic, + schemaregistry.EncodingPhase, schemaregistry.Read, nil, &info, payload) + if err != nil { + return nil, err + } + payload = msg.([]byte) readerMeta, err := s.GetReaderSchema(subject) if err != nil { return nil, err @@ -574,7 +586,6 @@ func (s *Deserializer) deserialize(topic string, headers []kafka.Header, payload return nil, err } } - var msg interface{} var protoMsg proto.Message if len(migrations) > 0 { dynamicMsg := dynamicpb.NewMessage(messageDesc.UnwrapMessage()) diff --git a/schemaregistry/serde/protobuf/protobuf_test.go b/schemaregistry/serde/protobuf/protobuf_test.go index f2cfa681a..a32e75ea7 100644 --- a/schemaregistry/serde/protobuf/protobuf_test.go +++ b/schemaregistry/serde/protobuf/protobuf_test.go @@ -718,6 +718,81 @@ func TestProtobufSerdeEncryption(t *testing.T) { serde.MaybeFail("deserialization", err, serde.Expect(newobj.(*test.Author).Name, obj.Name)) } +func TestProtobufSerdePayloadEncryption(t *testing.T) { + serde.MaybeFail = serde.InitFailFunc(t) + var err error + + conf := schemaregistry.NewConfig("mock://") + + client, err := schemaregistry.NewClient(conf) + serde.MaybeFail("Schema Registry configuration", err) + + serConfig := NewSerializerConfig() + serConfig.AutoRegisterSchemas = false + serConfig.UseLatestVersion = true + serConfig.RuleConfig = map[string]string{ + "secret": "mysecret", + } + ser, err := NewSerializer(client, serde.ValueSerde, serConfig) + serde.MaybeFail("Serializer configuration", err) + + encRule := schemaregistry.Rule{ + Name: "test-encrypt", + Kind: "TRANSFORM", + Mode: "WRITEREAD", + Type: "ENCRYPT_PAYLOAD", + Params: map[string]string{ + "encrypt.kek.name": "kek1", + "encrypt.kms.type": "local-kms", + "encrypt.kms.key.id": "mykey", + }, + OnFailure: "ERROR,NONE", + } + ruleSet := schemaregistry.RuleSet{ + EncodingRules: []schemaregistry.Rule{encRule}, + } + + info := schemaregistry.SchemaInfo{ + Schema: authorSchema, + SchemaType: "PROTOBUF", + RuleSet: &ruleSet, + } + + id, err := client.Register("topic1-value", info, false) + serde.MaybeFail("Schema registration", err) + if id <= 0 { + t.Errorf("Expected valid schema id, found %d", id) + } + + obj := test.Author{ + Name: "Kafka", + Id: 123, + Picture: []byte{1, 2}, + Works: []string{"The Castle", "The Trial"}, + PiiOneof: &test.Author_OneofString{OneofString: "oneof"}, + } + + bytes, err := ser.Serialize("topic1", &obj) + serde.MaybeFail("serialization", err) + + deserConfig := NewDeserializerConfig() + deserConfig.RuleConfig = map[string]string{ + "secret": "mysecret", + } + deser, err := NewDeserializer(client, serde.ValueSerde, deserConfig) + serde.MaybeFail("Deserializer configuration", err) + deser.Client = ser.Client + + err = deser.ProtoRegistry.RegisterMessage(obj.ProtoReflect().Type()) + serde.MaybeFail("register message", err) + + newobj, err := deser.Deserialize("topic1", bytes) + serde.MaybeFail("deserialization", err, serde.Expect(newobj.(*test.Author).Name, obj.Name)) + + err = deser.DeserializeInto("topic1", bytes, newobj) + serde.MaybeFail("deserialization", err, serde.Expect(newobj.(*test.Author).Name, obj.Name)) +} + func TestProtobufSerdeJSONataFullyCompatible(t *testing.T) { serde.MaybeFail = serde.InitFailFunc(t) var err error diff --git a/schemaregistry/serde/serde.go b/schemaregistry/serde/serde.go index a595a820b..36ed1c640 100644 --- a/schemaregistry/serde/serde.go +++ b/schemaregistry/serde/serde.go @@ -761,7 +761,7 @@ func (s *Serde) GetMigrations(subject string, topic string, sourceInfo *schemare previous = version continue } - if version.RuleSet != nil && version.RuleSet.HasRules(migrationMode) { + if version.RuleSet != nil && version.RuleSet.HasRules(schemaregistry.MigrationPhase, migrationMode) { var m Migration if migrationMode == schemaregistry.Upgrade { m = Migration{ @@ -819,7 +819,8 @@ type Migration struct { func (s *Serde) ExecuteMigrations(migrations []Migration, subject string, topic string, msg interface{}) (interface{}, error) { var err error for _, migration := range migrations { - msg, err = s.ExecuteRules(subject, topic, migration.RuleMode, + msg, err = s.ExecuteRulesWithPhase(subject, topic, + schemaregistry.MigrationPhase, migration.RuleMode, &migration.Source.SchemaInfo, &migration.Target.SchemaInfo, msg) if err != nil { return nil, err @@ -830,6 +831,13 @@ func (s *Serde) ExecuteMigrations(migrations []Migration, subject string, topic // ExecuteRules executes the given rules func (s *Serde) ExecuteRules(subject string, topic string, ruleMode schemaregistry.RuleMode, + source *schemaregistry.SchemaInfo, target *schemaregistry.SchemaInfo, msg interface{}) (interface{}, error) { + return s.ExecuteRulesWithPhase(subject, topic, schemaregistry.DomainPhase, ruleMode, source, target, msg) +} + +// ExecuteRulesWithPhase executes the given rules +func (s *Serde) ExecuteRulesWithPhase(subject string, topic string, + rulePhase schemaregistry.RulePhase, ruleMode schemaregistry.RuleMode, source *schemaregistry.SchemaInfo, target *schemaregistry.SchemaInfo, msg interface{}) (interface{}, error) { if msg == nil || target == nil { return msg, nil @@ -847,7 +855,11 @@ func (s *Serde) ExecuteRules(subject string, topic string, ruleMode schemaregist } default: if target.RuleSet != nil { - rules = target.RuleSet.DomainRules + if rulePhase == schemaregistry.EncodingPhase { + rules = target.RuleSet.EncodingRules + } else { + rules = target.RuleSet.DomainRules + } if ruleMode == schemaregistry.Read { // Execute read rules in reverse order for symmetry rules = reverseRules(rules)