Skip to content

Commit 2ec9417

Browse files
formatted with black
1 parent caff3dc commit 2ec9417

File tree

1 file changed

+21
-58
lines changed

1 file changed

+21
-58
lines changed

kernel_tuner/observers/nvml.py

Lines changed: 21 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -15,9 +15,7 @@
1515
class nvml:
1616
"""Class that gathers the NVML functionality for one device."""
1717

18-
def __init__(
19-
self, device_id=0, nvidia_smi_fallback="nvidia-smi", use_locked_clocks=False
20-
):
18+
def __init__(self, device_id=0, nvidia_smi_fallback="nvidia-smi", use_locked_clocks=False):
2119
"""Create object to control device using NVML."""
2220
pynvml.nvmlInit()
2321
self.dev = pynvml.nvmlDeviceGetHandleByIndex(device_id)
@@ -26,9 +24,7 @@ def __init__(
2624

2725
try:
2826
self.pwr_limit_default = pynvml.nvmlDeviceGetPowerManagementLimit(self.dev)
29-
self.pwr_constraints = pynvml.nvmlDeviceGetPowerManagementLimitConstraints(
30-
self.dev
31-
)
27+
self.pwr_constraints = pynvml.nvmlDeviceGetPowerManagementLimitConstraints(self.dev)
3228
except pynvml.NVMLError_NotSupported:
3329
self.pwr_limit_default = None
3430
# inverted range to make all range checks fail
@@ -52,9 +48,7 @@ def __init__(
5248
self.gr_clock_default = pynvml.nvmlDeviceGetDefaultApplicationsClock(
5349
self.dev, pynvml.NVML_CLOCK_GRAPHICS
5450
)
55-
self.mem_clock_default = pynvml.nvmlDeviceGetDefaultApplicationsClock(
56-
self.dev, pynvml.NVML_CLOCK_MEM
57-
)
51+
self.mem_clock_default = pynvml.nvmlDeviceGetDefaultApplicationsClock(self.dev, pynvml.NVML_CLOCK_MEM)
5852
except pynvml.NVMLError_NotSupported:
5953
self.gr_clock_default = None
6054
self.sm_clock_default = None
@@ -67,9 +61,7 @@ def __init__(
6761
# gather the supported gr clocks for each supported mem clock into a dict
6862
self.supported_gr_clocks = {}
6963
for mem_clock in self.supported_mem_clocks:
70-
supported_gr_clocks = pynvml.nvmlDeviceGetSupportedGraphicsClocks(
71-
self.dev, mem_clock
72-
)
64+
supported_gr_clocks = pynvml.nvmlDeviceGetSupportedGraphicsClocks(self.dev, mem_clock)
7365
self.supported_gr_clocks[mem_clock] = supported_gr_clocks
7466

7567
# test whether locked gr clocks and mem clocks are supported
@@ -132,9 +124,7 @@ def persistence_mode(self):
132124
@persistence_mode.setter
133125
def persistence_mode(self, new_mode):
134126
if new_mode not in [0, 1]:
135-
raise ValueError(
136-
"Illegal value for persistence mode, should be either 0 or 1"
137-
)
127+
raise ValueError("Illegal value for persistence mode, should be either 0 or 1")
138128
if self.persistence_mode == new_mode:
139129
return
140130
try:
@@ -155,7 +145,9 @@ def set_clocks(self, mem_clock, gr_clock):
155145
if mem_clock not in self.supported_mem_clocks:
156146
raise ValueError("Illegal value for memory clock")
157147
if gr_clock not in self.supported_gr_clocks[mem_clock]:
158-
raise ValueError(f"Graphics clock incompatible with memory clock ({mem_clock}), compatible graphics clocks: {self.supported_gr_clocks[mem_clock]}")
148+
raise ValueError(
149+
f"Graphics clock incompatible with memory clock ({mem_clock}), compatible graphics clocks: {self.supported_gr_clocks[mem_clock]}"
150+
)
159151

160152
# Check whether persistence mode is set. Without persistence mode, setting the clocks is not meaningful
161153
# I deliberately removed the try..except clause here, if setting persistence mode fails, setting the clocks should fail
@@ -185,7 +177,6 @@ def set_clocks(self, mem_clock, gr_clock):
185177
# Store the fact that we have modified the clocks
186178
self.modified_clocks = True
187179

188-
189180
def reset_clocks(self):
190181
"""Reset the clocks to the default clock if the device uses a non default clock."""
191182
if self.use_locked_clocks:
@@ -212,16 +203,9 @@ def reset_clocks(self):
212203
subprocess.run(args, check=True)
213204

214205
elif self.gr_clock_default is not None:
215-
gr_app_clock = pynvml.nvmlDeviceGetApplicationsClock(
216-
self.dev, pynvml.NVML_CLOCK_GRAPHICS
217-
)
218-
mem_app_clock = pynvml.nvmlDeviceGetApplicationsClock(
219-
self.dev, pynvml.NVML_CLOCK_MEM
220-
)
221-
if (
222-
gr_app_clock != self.gr_clock_default
223-
or mem_app_clock != self.mem_clock_default
224-
):
206+
gr_app_clock = pynvml.nvmlDeviceGetApplicationsClock(self.dev, pynvml.NVML_CLOCK_GRAPHICS)
207+
mem_app_clock = pynvml.nvmlDeviceGetApplicationsClock(self.dev, pynvml.NVML_CLOCK_MEM)
208+
if gr_app_clock != self.gr_clock_default or mem_app_clock != self.mem_clock_default:
225209
self.set_clocks(self.mem_clock_default, self.gr_clock_default)
226210

227211
@property
@@ -246,9 +230,7 @@ def mem_clock(self):
246230
mem_clock = pynvml.nvmlDeviceGetClockInfo(self.dev, pynvml.NVML_CLOCK_MEM)
247231
return min(self.supported_mem_clocks, key=lambda x: abs(x - mem_clock))
248232
else:
249-
return pynvml.nvmlDeviceGetApplicationsClock(
250-
self.dev, pynvml.NVML_CLOCK_MEM
251-
)
233+
return pynvml.nvmlDeviceGetApplicationsClock(self.dev, pynvml.NVML_CLOCK_MEM)
252234

253235
@mem_clock.setter
254236
def mem_clock(self, new_clock):
@@ -269,9 +251,7 @@ def auto_boost(self):
269251
def auto_boost(self, setting):
270252
# might need to use pynvml.NVML_FEATURE_DISABLED or pynvml.NVML_FEATURE_ENABLED instead of 0 or 1
271253
if setting not in [0, 1]:
272-
raise ValueError(
273-
"Illegal value for auto boost enabled, should be either 0 or 1"
274-
)
254+
raise ValueError("Illegal value for auto boost enabled, should be either 0 or 1")
275255
pynvml.nvmlDeviceSetAutoBoostedClocksEnabled(self.dev, setting)
276256
self._auto_boost = pynvml.nvmlDeviceGetAutoBoostedClocksEnabled(self.dev)[0]
277257

@@ -363,9 +343,7 @@ def __init__(
363343
if any([obs in self.needs_power for obs in observables]):
364344
self.measure_power = True
365345
power_observables = [obs for obs in observables if obs in self.needs_power]
366-
self.continuous_observer = NVMLPowerObserver(
367-
power_observables, self, self.nvml, continous_duration
368-
)
346+
self.continuous_observer = NVMLPowerObserver(power_observables, self, self.nvml, continous_duration)
369347

370348
# remove power observables
371349
self.observables = [obs for obs in observables if obs not in self.needs_power]
@@ -380,11 +358,7 @@ def __init__(
380358
for obs in self.observables:
381359
self.results[obs + "s"] = []
382360

383-
self.during_obs = [
384-
obs
385-
for obs in observables
386-
if obs in ["core_freq", "mem_freq", "temperature"]
387-
]
361+
self.during_obs = [obs for obs in observables if obs in ["core_freq", "mem_freq", "temperature"]]
388362
self.iteration = {obs: [] for obs in self.during_obs}
389363

390364
def before_start(self):
@@ -406,15 +380,11 @@ def during(self):
406380
if "mem_freq" in self.observables:
407381
self.iteration["mem_freq"].append(self.nvml.mem_clock)
408382
if self.record_gr_voltage:
409-
self.gr_voltage_readings.append(
410-
[time.perf_counter() - self.t0, self.nvml.gr_voltage()]
411-
)
383+
self.gr_voltage_readings.append([time.perf_counter() - self.t0, self.nvml.gr_voltage()])
412384

413385
def after_finish(self):
414386
if "temperature" in self.observables:
415-
self.results["temperatures"].append(
416-
np.average(self.iteration["temperature"])
417-
)
387+
self.results["temperatures"].append(np.average(self.iteration["temperature"]))
418388
if "core_freq" in self.observables:
419389
self.results["core_freqs"].append(np.average(self.iteration["core_freq"]))
420390
if "mem_freq" in self.observables:
@@ -423,12 +393,8 @@ def after_finish(self):
423393
if "gr_voltage" in self.observables:
424394
execution_time = time.time() - self.t0
425395
gr_voltage_readings = self.gr_voltage_readings
426-
gr_voltage_readings = [
427-
[0.0, gr_voltage_readings[0][1]]
428-
] + gr_voltage_readings
429-
gr_voltage_readings = gr_voltage_readings + [
430-
[execution_time, gr_voltage_readings[-1][1]]
431-
]
396+
gr_voltage_readings = [[0.0, gr_voltage_readings[0][1]]] + gr_voltage_readings
397+
gr_voltage_readings = gr_voltage_readings + [[execution_time, gr_voltage_readings[-1][1]]]
432398
# time in s, graphics voltage in millivolts
433399
self.results["gr_voltages"].append(np.average(gr_voltage_readings[:][:][1]))
434400

@@ -490,8 +456,7 @@ def during(self):
490456
timestamp = time.perf_counter() - self.t0
491457
# only store the result if we get a new measurement from NVML
492458
if len(self.power_readings) == 0 or (
493-
self.power_readings[-1][1] != power_usage
494-
or timestamp - self.power_readings[-1][0] > 0.01
459+
self.power_readings[-1][1] != power_usage or timestamp - self.power_readings[-1][0] > 0.01
495460
):
496461
self.power_readings.append([timestamp, power_usage])
497462

@@ -538,9 +503,7 @@ def get_nvml_pwr_limits(device, n=None, quiet=False):
538503
n = int((power_limit_max - power_limit_min) / power_limit_round) + 1
539504

540505
# Rounded power limit values
541-
power_limits = power_limit_round * np.round(
542-
(np.linspace(power_limit_min, power_limit_max, n) / power_limit_round)
543-
)
506+
power_limits = power_limit_round * np.round((np.linspace(power_limit_min, power_limit_max, n) / power_limit_round))
544507
power_limits = sorted(list(set([int(power_limit) for power_limit in power_limits])))
545508
tune_params["nvml_pwr_limit"] = power_limits
546509

0 commit comments

Comments
 (0)