Skip to content

Commit 596ef3d

Browse files
committed
Merge branch 'master' of github.com:petercorke/spatialmath-python
2 parents 3e8a86c + 43cef40 commit 596ef3d

File tree

2 files changed

+67
-40
lines changed

2 files changed

+67
-40
lines changed

spatialmath/base/argcheck.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,12 +12,14 @@
1212
# pylint: disable=invalid-name
1313

1414
import math
15+
from typing import Union
1516
import numpy as np
1617
from spatialmath.base import symbolic as sym
1718

1819
# valid scalar types
1920
_scalartypes = (int, np.integer, float, np.floating) + sym.symtype
2021

22+
ArrayLike = Union[list, np.ndarray, tuple, set]
2123

2224
def isscalar(x):
2325
"""
@@ -256,7 +258,7 @@ def verifymatrix(m, shape):
256258
# and not np.iscomplex(m) checks every element, would need to be not np.any(np.iscomplex(m)) which seems expensive
257259

258260

259-
def getvector(v, dim=None, out="array", dtype=np.float64):
261+
def getvector(v, dim=None, out="array", dtype=np.float64) -> ArrayLike:
260262
"""
261263
Return a vector value
262264
@@ -451,7 +453,7 @@ def isvector(v, dim=None):
451453
return False
452454

453455

454-
def getunit(v, unit="rad"):
456+
def getunit(v, unit="rad") -> ArrayLike:
455457
"""
456458
Convert value according to angular units
457459

spatialmath/baseposelist.py

Lines changed: 63 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -12,19 +12,20 @@
1212

1313
_numtypes = (int, np.int64, float, np.float64)
1414

15+
1516
class BasePoseList(UserList, ABC):
1617
"""
1718
List properties for spatial math classes
1819
1920
Each of the spatial math classes behaves like a regular Python object and
20-
an instance contains a value of a particular type, for example an SE(3)
21+
an instance contains a value of a particular type, for example an SE(3)
2122
matrix, a unit quaternion, a twist etc.
2223
2324
This class adds list-like capabilities to each of spatial math classes. This
24-
means that an instance is not limited to holding just a single value (a
25-
singleton instance), it can hold a list of values. That list can contain
25+
means that an instance is not limited to holding just a single value (a
26+
singleton instance), it can hold a list of values. That list can contain
2627
zero or more items. This is helpful for:
27-
28+
2829
- storing sequences (trajectories) where it is important to know that all
2930
elements in the sequence are of the same time and have valid values
3031
- arrays of the same type to enable C++ like programming patterns
@@ -86,7 +87,7 @@ def _import(self, x, check=True):
8687
def Empty(cls):
8788
"""
8889
Construct an empty instance (BasePoseList superclass method)
89-
90+
9091
:return: pose instance with zero values
9192
9293
Example::
@@ -117,7 +118,7 @@ def Alloc(cls, n=1):
117118
can be referenced ``X[i]`` or assigned to ``X[i] = ...``.
118119
119120
.. note:: The default value depends on the pose class and is the result
120-
of the empty constructor. For ``SO2``,
121+
of the empty constructor. For ``SO2``,
121122
``SE2``, ``SO3``, ``SE3`` it is an identity matrix, for a
122123
twist class ``Twist2`` or ``Twist3`` it is a zero vector,
123124
for a ``UnitQuaternion`` or ``Quaternion`` it is a zero
@@ -195,10 +196,16 @@ def arghandler(self, arg, convertfrom=(), check=True):
195196

196197
elif type(arg[0]) == type(self):
197198
# possibly a list of objects of same type
198-
assert all(map(lambda x: type(x) == type(self), arg)), 'elements of list are incorrect type'
199+
assert all(
200+
map(lambda x: type(x) == type(self), arg)
201+
), "elements of list are incorrect type"
199202
self.data = [x.A for x in arg]
200203

201-
elif argcheck.isnumberlist(arg) and len(self.shape) == 1 and len(arg) == self.shape[0]:
204+
elif (
205+
argcheck.isnumberlist(arg)
206+
and len(self.shape) == 1
207+
and len(arg) == self.shape[0]
208+
):
202209
self.data = [np.array(arg)]
203210

204211
else:
@@ -215,7 +222,9 @@ def arghandler(self, arg, convertfrom=(), check=True):
215222
# get method to convert from arg to self types
216223
converter = getattr(arg.__class__, type(self).__name__)
217224
except AttributeError:
218-
raise ValueError('argument has no conversion method to this type') from None
225+
raise ValueError(
226+
"argument has no conversion method to this type"
227+
) from None
219228
self.data = [converter(arg).A]
220229

221230
else:
@@ -224,6 +233,15 @@ def arghandler(self, arg, convertfrom=(), check=True):
224233

225234
return True
226235

236+
@property
237+
def __array_interface__(self):
238+
"""
239+
Copies the numpy array interface from the first numpy array
240+
so that C extenstions with this spatial math class have direct
241+
access to the underlying numpy array
242+
"""
243+
return self.data[0].__array_interface__
244+
227245
@property
228246
def _A(self):
229247
"""
@@ -238,18 +256,18 @@ def _A(self):
238256
return self.data
239257

240258
@property
241-
def A(self):
259+
def A(self) -> np.ndarray:
242260
"""
243261
Array value of an instance (BasePoseList superclass method)
244262
245263
:return: NumPy array value of this instance
246264
:rtype: ndarray
247265
248-
- ``X.A`` is a NumPy array that represents the value of this instance,
266+
- ``X.A`` is a NumPy array that represents the value of this instance,
249267
and has a shape given by ``X.shape``.
250268
251269
.. note:: This assumes that ``len(X)`` == 1, ie. it is a single-valued
252-
instance.
270+
instance.
253271
"""
254272

255273
if len(self.data) == 1:
@@ -270,9 +288,9 @@ def __getitem__(self, i):
270288
:raises IndexError: if the element is out of bounds
271289
272290
Note that only a single index is supported, slices are not.
273-
291+
274292
Example::
275-
293+
276294
>>> x = X.Alloc(10)
277295
>>> len(x)
278296
10
@@ -296,14 +314,19 @@ def __getitem__(self, i):
296314
else:
297315
# stop is positive, use it directly
298316
end = i.stop
299-
return self.__class__([self.data[k] for k in range(i.start or 0, end, i.step or 1)])
317+
return self.__class__(
318+
[self.data[k] for k in range(i.start or 0, end, i.step or 1)]
319+
)
300320
else:
301-
return self.__class__(self.data[i], check=False)
302-
321+
ret = self.__class__(self.data[i], check=False)
322+
# ret.__array_interface__ = self.data[i].__array_interface__
323+
return ret
324+
# return self.__class__(self.data[i], check=False)
325+
303326
def __setitem__(self, i, value):
304327
"""
305328
Assign a value to an instance (BasePoseList superclass method)
306-
329+
307330
:param i: index of element to assign to
308331
:type i: int
309332
:param value: the value to insert
@@ -312,7 +335,7 @@ def __setitem__(self, i, value):
312335
313336
Assign the argument to an element of the object's internal list of values.
314337
This supports the assignement operator, for example::
315-
338+
316339
>>> x = X.Alloc(10)
317340
>>> len(x)
318341
10
@@ -324,7 +347,9 @@ def __setitem__(self, i, value):
324347
if not type(self) == type(value):
325348
raise ValueError("can't insert different type of object")
326349
if len(value) > 1:
327-
raise ValueError("can't insert a multivalued element - must have len() == 1")
350+
raise ValueError(
351+
"can't insert a multivalued element - must have len() == 1"
352+
)
328353
self.data[i] = value.A
329354

330355
# flag these binary operators as being not supported
@@ -343,7 +368,7 @@ def __ge__(self, other):
343368
def append(self, item):
344369
"""
345370
Append a value to an instance (BasePoseList superclass method)
346-
371+
347372
:param x: the value to append
348373
:type x: Quaternion or UnitQuaternion instance
349374
:raises ValueError: incorrect type of appended object
@@ -361,18 +386,17 @@ def append(self, item):
361386
362387
where ``X`` is any of the SMTB classes.
363388
"""
364-
#print('in append method')
389+
# print('in append method')
365390
if not type(self) == type(item):
366391
raise ValueError("can't append different type of object")
367392
if len(item) > 1:
368393
raise ValueError("can't append a multivalued instance - use extend")
369394
super().append(item.A)
370-
371395

372396
def extend(self, iterable):
373397
"""
374398
Extend sequence of values in an instance (BasePoseList superclass method)
375-
399+
376400
:param x: the value to extend
377401
:type x: instance of same type
378402
:raises ValueError: incorrect type of appended object
@@ -390,7 +414,7 @@ def extend(self, iterable):
390414
391415
where ``X`` is any of the SMTB classes.
392416
"""
393-
#print('in extend method')
417+
# print('in extend method')
394418
if not type(self) == type(iterable):
395419
raise ValueError("can't append different type of object")
396420
super().extend(iterable._A)
@@ -427,9 +451,11 @@ def insert(self, i, item):
427451
if not type(self) == type(item):
428452
raise ValueError("can't insert different type of object")
429453
if len(item) > 1:
430-
raise ValueError("can't insert a multivalued instance - must have len() == 1")
454+
raise ValueError(
455+
"can't insert a multivalued instance - must have len() == 1"
456+
)
431457
super().insert(i, item._A)
432-
458+
433459
def pop(self, i=-1):
434460
"""
435461
Pop value from an instance (BasePoseList superclass method)
@@ -442,7 +468,7 @@ def pop(self, i=-1):
442468
443469
Removes a value from the value list and returns it. The original
444470
instance is modified.
445-
471+
446472
Example::
447473
448474
>>> x = X.Alloc(10)
@@ -462,7 +488,7 @@ def pop(self, i=-1):
462488
def binop(self, right, op, op2=None, list1=True):
463489
"""
464490
Perform binary operation
465-
491+
466492
:param left: left operand
467493
:type left: BasePoseList subclass
468494
:param right: right operand
@@ -523,7 +549,7 @@ def binop(self, right, op, op2=None, list1=True):
523549

524550
# class * class
525551
if len(left) == 1:
526-
# singleton *
552+
# singleton *
527553
if argcheck.isscalar(right):
528554
if list1:
529555
return [op(left._A, right)]
@@ -539,7 +565,7 @@ def binop(self, right, op, op2=None, list1=True):
539565
# singleton * non-singleton
540566
return [op(left.A, x) for x in right.A]
541567
else:
542-
# non-singleton *
568+
# non-singleton *
543569
if argcheck.isscalar(right):
544570
return [op(x, right) for x in left.A]
545571
elif len(right) == 1:
@@ -549,12 +575,12 @@ def binop(self, right, op, op2=None, list1=True):
549575
# non-singleton * non-singleton
550576
return [op(x, y) for (x, y) in zip(left.A, right.A)]
551577
else:
552-
raise ValueError('length of lists to == must be same length')
578+
raise ValueError("length of lists to == must be same length")
553579

554580
# if isinstance(right, left.__class__):
555581
# # class * class
556582
# if len(left) == 1:
557-
# # singleton *
583+
# # singleton *
558584
# if len(right) == 1:
559585
# # singleton * singleton
560586
# if list1:
@@ -565,7 +591,7 @@ def binop(self, right, op, op2=None, list1=True):
565591
# # singleton * non-singleton
566592
# return [op(left.A, x) for x in right.A]
567593
# else:
568-
# # non-singleton *
594+
# # non-singleton *
569595
# if len(right) == 1:
570596
# # non-singleton * singleton
571597
# return [op(x, right.A) for x in left.A]
@@ -587,7 +613,7 @@ def binop(self, right, op, op2=None, list1=True):
587613
def unop(self, op, matrix=False):
588614
"""
589615
Perform unary operation
590-
616+
591617
:param self: operand
592618
:type self: BasePoseList subclass
593619
:param op: unnary operation
@@ -598,7 +624,7 @@ def unop(self, op, matrix=False):
598624
:rtype: list or NumPy array
599625
600626
The is a helper method for implementing unary operations where the
601-
operand has multiple value. This method computes the value of
627+
operand has multiple value. This method computes the value of
602628
the operation for all input values and returns the result as either
603629
a list or as a matrix which vertically stacks the results.
604630
@@ -613,7 +639,7 @@ def unop(self, op, matrix=False):
613639
========= ==== ===================================
614640
615641
The result is:
616-
642+
617643
- a list of values if ``matrix==False``, or
618644
- a 2D NumPy stack of values if ``matrix==True``, it is assumed
619645
that the value is a 1D array.
@@ -623,4 +649,3 @@ def unop(self, op, matrix=False):
623649
return np.vstack([op(x) for x in self.data])
624650
else:
625651
return [op(x) for x in self.data]
626-

0 commit comments

Comments
 (0)