Skip to content

Commit 8ebf8b8

Browse files
Merge pull request #261 from KernelTuner/update-pmt
Update PMTObserver for latest PMT changes
2 parents 99b5c90 + db9fc45 commit 8ebf8b8

File tree

2 files changed

+6
-6
lines changed

2 files changed

+6
-6
lines changed

examples/cuda/vector_add_observers_pmt.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,10 +31,10 @@ def tune():
3131
tune_params = dict()
3232
tune_params["block_size_x"] = [128+64*i for i in range(15)]
3333

34-
pmtobserver = PMTObserver(["nvml", "rapl"])
34+
pmtobserver = PMTObserver([("nvidia", 0), "rapl"])
3535

3636
metrics = OrderedDict()
37-
metrics["GPU W"] = lambda p: p["nvml_power"]
37+
metrics["GPU W"] = lambda p: p["nvidia_power"]
3838
metrics["CPU W"] = lambda p: p["rapl_power"]
3939

4040
results, env = tune_kernel("vector_add", kernel_string, size, args, tune_params, observers=[pmtobserver], metrics=metrics, iterations=32)

kernel_tuner/observers/pmt.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -38,17 +38,17 @@ def __init__(self, observable=None):
3838
if type(observable) is dict:
3939
pass
4040
elif type(observable) is list:
41-
# user specifies a list of platforms as observable
42-
observable = dict([(obs, 0) for obs in observable])
41+
# user specifies a list of platforms as observable, optionally with an argument
42+
observable = dict([obs if isinstance(obs, tuple) else (obs, None) for obs in observable])
4343
else:
4444
# User specifices a string (single platform) as observable
4545
observable = {observable: None}
46-
supported = ["arduino", "jetson", "likwid", "nvml", "rapl", "rocm", "xilinx"]
46+
supported = ["powersensor2", "powersensor3", "nvidia", "likwid", "rapl", "rocm", "xilinx"]
4747
for obs in observable.keys():
4848
if not obs in supported:
4949
raise ValueError(f"Observable {obs} not in supported: {supported}")
5050

51-
self.pms = [pmt.get_pmt(obs[0], obs[1]) for obs in observable.items()]
51+
self.pms = [pmt.create(obs[0], obs[1]) for obs in observable.items()]
5252
self.pm_names = list(observable.keys())
5353

5454
self.begin_states = [None] * len(self.pms)

0 commit comments

Comments
 (0)