Skip to content

Commit 60da654

Browse files
committed
Fix torch imports for rpi (#1004)
* Moving torch imports into functions so that donkeycar can continue to work without a pytorch installation. This is the default setup on RPi.
* Bumped version

(cherry picked from commit 14db8a8)
1 parent 2a6e12d commit 60da654

File tree

2 files changed

+5
-4
lines changed

2 files changed

+5
-4
lines changed

donkeycar/parts/interpreter.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@
55
from typing import Union, Sequence, List
66

77
import tensorflow as tf
8-
import torch
98
from tensorflow import keras
109

1110
from tensorflow.python.framework.convert_to_constants import \
@@ -171,6 +170,7 @@ def load_weights(self, model_path: str, by_name: bool = True) -> \
171170
def summary(self) -> str:
172171
return self.model.summary()
173172

173+
174174
class FastAIInterpreter(Interpreter):
175175

176176
def __init__(self):
@@ -206,14 +206,15 @@ def invoke(self, inputs):
206206

207207
def predict(self, img_arr: np.ndarray, other_arr: np.ndarray) \
208208
-> Sequence[Union[float, np.ndarray]]:
209-
209+
import torch
210210
inputs = torch.unsqueeze(img_arr, 0)
211211
if other_arr is not None:
212212
#other_arr = np.expand_dims(other_arr, axis=0)
213213
inputs = [img_arr, other_arr]
214214
return self.invoke(inputs)
215215

216216
def load(self, model_path: str) -> None:
217+
import torch
217218
logger.info(f'Loading model {model_path}')
218219
if torch.cuda.is_available():
219220
logger.info("using cuda for torch inference")
@@ -228,6 +229,7 @@ def load(self, model_path: str) -> None:
228229
def summary(self) -> str:
229230
return self.model
230231

232+
231233
class TfLite(Interpreter):
232234
"""
233235
This class wraps around the TensorFlow Lite interpreter.

donkeycar/utils.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -439,8 +439,6 @@ def get_model_by_type(model_type: str, cfg: 'Config') -> Union['KerasPilot', 'Fa
439439
from donkeycar.parts.interpreter import KerasInterpreter, TfLite, TensorRT, \
440440
FastAIInterpreter
441441

442-
from donkeycar.parts.fastai import FastAILinear
443-
444442
if model_type is None:
445443
model_type = cfg.DEFAULT_MODEL_TYPE
446444
logger.info(f'get_model_by_type: model type is: {model_type}')
@@ -455,6 +453,7 @@ def get_model_by_type(model_type: str, cfg: 'Config') -> Union['KerasPilot', 'Fa
455453
interpreter = FastAIInterpreter()
456454
used_model_type = model_type.replace('fastai_', '')
457455
if used_model_type == "linear":
456+
from donkeycar.parts.fastai import FastAILinear
458457
return FastAILinear(interpreter=interpreter, input_shape=input_shape)
459458
else:
460459
interpreter = KerasInterpreter()

0 commit comments

Comments (0)