Skip to content

Commit f09dd04

Browse files
landritotheacodes
authored andcommitted
Add protobuf_helpers.field_mask to calculate a field mask from two messages (#5320)
1 parent cf94432 commit f09dd04

File tree

2 files changed

+334
-9
lines changed

2 files changed

+334
-9
lines changed

google/api_core/protobuf_helpers.py

Lines changed: 105 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -15,11 +15,25 @@
1515
"""Helpers for :mod:`protobuf`."""
1616

1717
import collections
18+
import copy
1819
import inspect
1920

20-
from google.protobuf.message import Message
21+
from google.protobuf import field_mask_pb2
22+
from google.protobuf import message
23+
from google.protobuf import wrappers_pb2
2124

2225
_SENTINEL = object()
26+
_WRAPPER_TYPES = (
27+
wrappers_pb2.BoolValue,
28+
wrappers_pb2.BytesValue,
29+
wrappers_pb2.DoubleValue,
30+
wrappers_pb2.FloatValue,
31+
wrappers_pb2.Int32Value,
32+
wrappers_pb2.Int64Value,
33+
wrappers_pb2.StringValue,
34+
wrappers_pb2.UInt32Value,
35+
wrappers_pb2.UInt64Value,
36+
)
2337

2438

2539
def from_any_pb(pb_type, any_pb):
@@ -73,13 +87,15 @@ def get_messages(module):
7387
module to find Message subclasses.
7488
7589
Returns:
76-
dict[str, Message]: A dictionary with the Message class names as
77-
keys, and the Message subclasses themselves as values.
90+
dict[str, google.protobuf.message.Message]: A dictionary with the
91+
Message class names as keys, and the Message subclasses themselves
92+
as values.
7893
"""
7994
answer = collections.OrderedDict()
8095
for name in dir(module):
8196
candidate = getattr(module, name)
82-
if inspect.isclass(candidate) and issubclass(candidate, Message):
97+
if (inspect.isclass(candidate) and
98+
issubclass(candidate, message.Message)):
8399
answer[name] = candidate
84100
return answer
85101

@@ -143,7 +159,7 @@ def get(msg_or_dict, key, default=_SENTINEL):
143159

144160
# Attempt to get the value from the two types of objects we know about.
145161
# If we get something else, complain.
146-
if isinstance(msg_or_dict, Message):
162+
if isinstance(msg_or_dict, message.Message):
147163
answer = getattr(msg_or_dict, key, default)
148164
elif isinstance(msg_or_dict, collections.Mapping):
149165
answer = msg_or_dict.get(key, default)
@@ -186,7 +202,7 @@ def _set_field_on_message(msg, key, value):
186202
# Assign the dictionary values to the protobuf message.
187203
for item_key, item_value in value.items():
188204
set(getattr(msg, key), item_key, item_value)
189-
elif isinstance(value, Message):
205+
elif isinstance(value, message.Message):
190206
getattr(msg, key).CopyFrom(value)
191207
else:
192208
setattr(msg, key, value)
@@ -205,7 +221,8 @@ def set(msg_or_dict, key, value):
205221
TypeError: If ``msg_or_dict`` is not a Message or dictionary.
206222
"""
207223
# Sanity check: Is our target object valid?
208-
if not isinstance(msg_or_dict, (collections.MutableMapping, Message)):
224+
if (not isinstance(msg_or_dict,
225+
(collections.MutableMapping, message.Message))):
209226
raise TypeError(
210227
'set() expected a dict or protobuf message, got {!r}.'.format(
211228
type(msg_or_dict)))
@@ -247,3 +264,84 @@ def setdefault(msg_or_dict, key, value):
247264
"""
248265
if not get(msg_or_dict, key, default=None):
249266
set(msg_or_dict, key, value)
267+
268+
269+
def field_mask(original, modified):
270+
"""Create a field mask by comparing two messages.
271+
272+
Args:
273+
original (~google.protobuf.message.Message): the original message.
274+
If set to None, this field will be interpretted as an empty
275+
message.
276+
modified (~google.protobuf.message.Message): the modified message.
277+
If set to None, this field will be interpretted as an empty
278+
message.
279+
280+
Returns:
281+
google.protobuf.field_mask_pb2.FieldMask: field mask that contains
282+
the list of field names that have different values between the two
283+
messages. If the messages are equivalent, then the field mask is empty.
284+
285+
Raises:
286+
ValueError: If the ``original`` or ``modified`` are not the same type.
287+
"""
288+
if original is None and modified is None:
289+
return field_mask_pb2.FieldMask()
290+
291+
if original is None and modified is not None:
292+
original = copy.deepcopy(modified)
293+
original.Clear()
294+
295+
if modified is None and original is not None:
296+
modified = copy.deepcopy(original)
297+
modified.Clear()
298+
299+
if type(original) != type(modified):
300+
raise ValueError(
301+
'expected that both original and modified should be of the '
302+
'same type, received "{!r}" and "{!r}".'.
303+
format(type(original), type(modified)))
304+
305+
return field_mask_pb2.FieldMask(
306+
paths=_field_mask_helper(original, modified))
307+
308+
309+
def _field_mask_helper(original, modified, current=''):
310+
answer = []
311+
312+
for name in original.DESCRIPTOR.fields_by_name:
313+
field_path = _get_path(current, name)
314+
315+
original_val = getattr(original, name)
316+
modified_val = getattr(modified, name)
317+
318+
if _is_message(original_val) or _is_message(modified_val):
319+
if original_val != modified_val:
320+
# Wrapper types do not need to include the .value part of the
321+
# path.
322+
if _is_wrapper(original_val) or _is_wrapper(modified_val):
323+
answer.append(field_path)
324+
elif not modified_val.ListFields():
325+
answer.append(field_path)
326+
else:
327+
answer.extend(_field_mask_helper(original_val,
328+
modified_val, field_path))
329+
else:
330+
if original_val != modified_val:
331+
answer.append(field_path)
332+
333+
return answer
334+
335+
336+
def _get_path(current, name):
337+
if not current:
338+
return name
339+
return '%s.%s' % (current, name)
340+
341+
342+
def _is_message(value):
343+
return isinstance(value, message.Message)
344+
345+
346+
def _is_wrapper(value):
347+
return type(value) in _WRAPPER_TYPES

0 commit comments

Comments
 (0)