Skip to content

Commit cb0e8d3

Browse files
committed
[Python] Add new methods into matrix class (full mat C API support)
1 parent 65d74a1 commit cb0e8d3

File tree

2 files changed

+77
-0
lines changed

2 files changed

+77
-0
lines changed

python/pycubool/bridge.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -99,6 +99,17 @@ def get_reduce_hints(time_check):
9999
return hints
100100

101101

102+
def get_reduce_vector_hints(transpose, time_check):
103+
hints = _hint_no
104+
105+
if transpose:
106+
hints |= _hint_transpose
107+
if time_check:
108+
hints |= _hint_time_check
109+
110+
return hints
111+
112+
102113
def get_kronecker_hints(time_check):
103114
hints = _hint_no
104115

@@ -128,6 +139,15 @@ def get_vxm_hints(time_check):
128139
return hints
129140

130141

142+
def get_mxv_hints(time_check):
143+
hints = _hint_no
144+
145+
if time_check:
146+
hints |= _hint_time_check
147+
148+
return hints
149+
150+
131151
def get_ewiseadd_hints(time_check):
132152
hints = _hint_no
133153

python/pycubool/matrix.py

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77

88
from . import wrapper
99
from . import bridge
10+
from . import vector
1011

1112

1213
__all__ = [
@@ -490,6 +491,34 @@ def extract_matrix(self, i, j, shape, out=None, time_check=False):
490491
bridge.check(status)
491492
return out
492493

494+
def extract_row(self, i, out=None):
495+
if out is None:
496+
out = Vector.empty(self.ncols)
497+
498+
status = wrapper.loaded_dll.cuBool_Matrix_ExtractRow(
499+
out.hnd,
500+
self.hnd,
501+
ctypes.c_uint(i),
502+
ctypes.c_uint(0)
503+
)
504+
505+
bridge.check(status)
506+
return out
507+
508+
def extract_col(self, j, out=None):
509+
if out is None:
510+
out = Vector.empty(self.nrows)
511+
512+
status = wrapper.loaded_dll.cuBool_Matrix_ExtractCol(
513+
out.hnd,
514+
self.hnd,
515+
ctypes.c_uint(j),
516+
ctypes.c_uint(0)
517+
)
518+
519+
bridge.check(status)
520+
return out
521+
493522
def mxm(self, other, out=None, accumulate=False, time_check=False):
494523
"""
495524
Matrix-matrix multiplication in boolean semiring with "x = and" and "+ = or" operations.
@@ -532,6 +561,20 @@ def mxm(self, other, out=None, accumulate=False, time_check=False):
532561
bridge.check(status)
533562
return out
534563

564+
def mxv(self, other, out=None, time_check=False):
565+
if out is None:
566+
out = Vector.empty(self.nrows)
567+
568+
status = wrapper.loaded_dll.cuBool_MxV(
569+
out.hnd,
570+
self.hnd,
571+
other.hnd,
572+
ctypes.c_uint(bridge.get_mxv_hints(time_check=time_check))
573+
)
574+
575+
bridge.check(status)
576+
return out
577+
535578
def kronecker(self, other, out=None, time_check=False):
536579
"""
537580
Matrix-matrix kronecker product with boolean "x = and" operation.
@@ -643,6 +686,20 @@ def reduce(self, time_check=False):
643686
bridge.check(status)
644687
return out
645688

689+
def reduce_vector(self, out=None, transpose=False, time_check=False):
690+
if out is None:
691+
nrows = self.ncols if transpose else self.nrows
692+
out = Vector.empty(nrows)
693+
694+
status = wrapper.loaded_dll.cuBool_Matrix_Reduce(
695+
out.hnd,
696+
self.hnd,
697+
ctypes.c_uint(bridge.get_reduce_vector_hints(transpose=transpose, time_check=time_check))
698+
)
699+
700+
bridge.check(status)
701+
return out
702+
646703
def equals(self, other) -> bool:
647704
"""
648705
Compare two matrices. Returns true if they are equal.

0 commit comments

Comments
 (0)