File tree Expand file tree Collapse file tree 3 files changed +5
-2
lines changed
nle_language_wrapper/agents/sample_factory Expand file tree Collapse file tree 3 files changed +5
-2
lines changed Original file line number Diff line number Diff line change @@ -44,7 +44,7 @@ def _tokenize(self, str_obsv):
44
44
max_length = self .cfg ["max_token_length" ],
45
45
)
46
46
# Sample factory insists on normalizing obs key.
47
- tokens .data ["obs" ] = torch .tensor ( 0 )
47
+ tokens .data ["obs" ] = torch .zeros ( 1 )
48
48
return tokens .data
49
49
50
50
def _convert_obsv_to_str (self , obsv ):
Original file line number Diff line number Diff line change @@ -17,6 +17,7 @@ known_third_party = ["nle"]
17
17
18
18
[tool .pylint .messages_control ]
19
19
max-line-length = 88
20
+ generated-members =" torch.*"
20
21
disable = [
21
22
" missing-function-docstring" ,
22
23
" missing-module-docstring" ,
Original file line number Diff line number Diff line change @@ -62,6 +62,7 @@ def build_extension(self, ext):
62
62
"dev" : [
63
63
"black>=22.6.0" ,
64
64
"flake8>=4.0.1" ,
65
+ "pylint>=2.15.8" ,
65
66
"pytest>=7.1.2" ,
66
67
"pytest-cov>=3.0.0" ,
67
68
"pytest-mock>=3.7.0" ,
@@ -72,7 +73,8 @@ def build_extension(self, ext):
72
73
"agent" : [
73
74
"sample_factory>=1.121.4" ,
74
75
"transformers>=4.17.0" ,
75
- "torch@https://download.pytorch.org/whl/cu111/torch-1.9.1%2Bcu111-cp39-cp39-linux_x86_64.whl" ,
76
+ "torch@https://download.pytorch.org/whl/cu111/"
77
+ "torch-1.9.1%2Bcu111-cp39-cp39-linux_x86_64.whl" ,
76
78
],
77
79
}
78
80
You can’t perform that action at this time.
0 commit comments