@@ -29,10 +29,20 @@ type redditWidget struct {
29
29
TopPeriod string `yaml:"top-period"`
30
30
Search string `yaml:"search"`
31
31
ExtraSortBy string `yaml:"extra-sort-by"`
32
- CommentsUrlTemplate string `yaml:"comments-url-template"`
32
+ CommentsURLTemplate string `yaml:"comments-url-template"`
33
33
Limit int `yaml:"limit"`
34
34
CollapseAfter int `yaml:"collapse-after"`
35
- RequestUrlTemplate string `yaml:"request-url-template"`
35
+ RequestURLTemplate string `yaml:"request-url-template"`
36
+
37
+ AppAuth struct {
38
+ Name string `yaml:"name"`
39
+ ID string `yaml:"id"`
40
+ Secret string `yaml:"secret"`
41
+
42
+ enabled bool
43
+ accessToken string
44
+ tokenExpiresAt time.Time
45
+ } `yaml:"app-auth"`
36
46
}
37
47
38
48
func (widget * redditWidget ) initialize () error {
@@ -48,20 +58,30 @@ func (widget *redditWidget) initialize() error {
48
58
widget .CollapseAfter = 5
49
59
}
50
60
51
- if ! isValidRedditSortType (widget .SortBy ) {
61
+ s := widget .SortBy
62
+ if s != "hot" && s != "new" && s != "top" && s != "rising" {
52
63
widget .SortBy = "hot"
53
64
}
54
65
55
- if ! isValidRedditTopPeriod (widget .TopPeriod ) {
66
+ p := widget .TopPeriod
67
+ if p != "hour" && p != "day" && p != "week" && p != "month" && p != "year" && p != "all" {
56
68
widget .TopPeriod = "day"
57
69
}
58
70
59
- if widget .RequestUrlTemplate != "" {
60
- if ! strings .Contains (widget .RequestUrlTemplate , "{REQUEST-URL}" ) {
71
+ if widget .RequestURLTemplate != "" {
72
+ if ! strings .Contains (widget .RequestURLTemplate , "{REQUEST-URL}" ) {
61
73
return errors .New ("no `{REQUEST-URL}` placeholder specified" )
62
74
}
63
75
}
64
76
77
+ a := & widget .AppAuth
78
+ if a .Name != "" || a .ID != "" || a .Secret != "" {
79
+ if a .Name == "" || a .ID == "" || a .Secret == "" {
80
+ return errors .New ("application name, client ID and client secret are required" )
81
+ }
82
+ a .enabled = true
83
+ }
84
+
65
85
widget .
66
86
withTitle ("r/" + widget .Subreddit ).
67
87
withTitleURL ("https://www.reddit.com/r/" + widget .Subreddit + "/" ).
@@ -70,35 +90,8 @@ func (widget *redditWidget) initialize() error {
70
90
return nil
71
91
}
72
92
73
- func isValidRedditSortType (sortBy string ) bool {
74
- return sortBy == "hot" ||
75
- sortBy == "new" ||
76
- sortBy == "top" ||
77
- sortBy == "rising"
78
- }
79
-
80
- func isValidRedditTopPeriod (period string ) bool {
81
- return period == "hour" ||
82
- period == "day" ||
83
- period == "week" ||
84
- period == "month" ||
85
- period == "year" ||
86
- period == "all"
87
- }
88
-
89
93
func (widget * redditWidget ) update (ctx context.Context ) {
90
- // TODO: refactor, use a struct to pass all of these
91
- posts , err := fetchSubredditPosts (
92
- widget .Subreddit ,
93
- widget .SortBy ,
94
- widget .TopPeriod ,
95
- widget .Search ,
96
- widget .CommentsUrlTemplate ,
97
- widget .RequestUrlTemplate ,
98
- widget .Proxy .client ,
99
- widget .ShowFlairs ,
100
- )
101
-
94
+ posts , err := widget .fetchSubredditPosts ()
102
95
if ! widget .canContinueUpdateAfterHandlingErr (err ) {
103
96
return
104
97
}
@@ -155,57 +148,65 @@ type subredditResponseJson struct {
155
148
} `json:"data"`
156
149
}
157
150
158
- func templateRedditCommentsURL ( template , subreddit , postId , postPath string ) string {
159
- template = strings .ReplaceAll (template , "{SUBREDDIT}" , subreddit )
151
+ func ( widget * redditWidget ) parseCustomCommentsURL ( subreddit , postId , postPath string ) string {
152
+ template : = strings .ReplaceAll (widget . CommentsURLTemplate , "{SUBREDDIT}" , subreddit )
160
153
template = strings .ReplaceAll (template , "{POST-ID}" , postId )
161
154
template = strings .ReplaceAll (template , "{POST-PATH}" , strings .TrimLeft (postPath , "/" ))
162
155
163
156
return template
164
157
}
165
158
166
- func fetchSubredditPosts (
167
- subreddit ,
168
- sort ,
169
- topPeriod ,
170
- search ,
171
- commentsUrlTemplate ,
172
- requestUrlTemplate string ,
173
- proxyClient * http.Client ,
174
- showFlairs bool ,
175
- ) (forumPostList , error ) {
159
+ func (widget * redditWidget ) fetchSubredditPosts () (forumPostList , error ) {
160
+ var client requestDoer = defaultHTTPClient
161
+ var baseURL string
162
+ var requestURL string
163
+ var headers http.Header
176
164
query := url.Values {}
177
- var requestUrl string
165
+ app := & widget . AppAuth
178
166
179
- if search != "" {
180
- query .Set ("q" , search + " subreddit:" + subreddit )
181
- query .Set ("sort" , sort )
182
- }
167
+ if ! app .enabled {
168
+ baseURL = "https://www.reddit.com"
169
+ headers = http.Header {
170
+ "User-Agent" : []string {getBrowserUserAgentHeader ()},
171
+ }
172
+ } else {
173
+ baseURL = "https://oauth.reddit.com"
174
+
175
+ if app .accessToken == "" || time .Now ().Add (time .Minute ).After (app .tokenExpiresAt ) {
176
+ if err := widget .fetchNewAppAccessToken (); err != nil {
177
+ return nil , fmt .Errorf ("fetching new app access token: %v" , err )
178
+ }
179
+ }
183
180
184
- if sort == "top" {
185
- query .Set ("t" , topPeriod )
181
+ headers = http.Header {
182
+ "Authorization" : []string {"Bearer " + app .accessToken },
183
+ "User-Agent" : []string {app .Name + "/1.0" },
184
+ }
186
185
}
187
186
188
- if search != "" {
189
- requestUrl = fmt .Sprintf ("https://www.reddit.com/search.json?%s" , query .Encode ())
187
+ if widget .Search != "" {
188
+ query .Set ("q" , widget .Search + " subreddit:" + widget .Subreddit )
189
+ query .Set ("sort" , widget .SortBy )
190
+ requestURL = fmt .Sprintf ("%s/search.json?%s" , baseURL , query .Encode ())
190
191
} else {
191
- requestUrl = fmt .Sprintf ("https://www.reddit.com/r/%s/%s.json?%s" , subreddit , sort , query .Encode ())
192
+ if widget .SortBy == "top" {
193
+ query .Set ("t" , widget .TopPeriod )
194
+ }
195
+ requestURL = fmt .Sprintf ("%s/r/%s/%s.json?%s" , baseURL , widget .Subreddit , widget .SortBy , query .Encode ())
192
196
}
193
197
194
- var client requestDoer = defaultHTTPClient
195
-
196
- if requestUrlTemplate != "" {
197
- requestUrl = strings .ReplaceAll (requestUrlTemplate , "{REQUEST-URL}" , url .QueryEscape (requestUrl ))
198
- } else if proxyClient != nil {
199
- client = proxyClient
198
+ if widget .RequestURLTemplate != "" {
199
+ requestURL = strings .ReplaceAll (widget .RequestURLTemplate , "{REQUEST-URL}" , requestURL )
200
+ } else if widget .Proxy .client != nil {
201
+ client = widget .Proxy .client
200
202
}
201
203
202
- request , err := http .NewRequest ("GET" , requestUrl , nil )
204
+ request , err := http .NewRequest ("GET" , requestURL , nil )
203
205
if err != nil {
204
206
return nil , err
205
207
}
208
+ request .Header = headers
206
209
207
- // Required to increase rate limit, otherwise Reddit randomly returns 429 even after just 2 requests
208
- setBrowserUserAgentHeader (request )
209
210
responseJson , err := decodeJsonFromRequest [subredditResponseJson ](client , request )
210
211
if err != nil {
211
212
return nil , err
@@ -226,10 +227,10 @@ func fetchSubredditPosts(
226
227
227
228
var commentsUrl string
228
229
229
- if commentsUrlTemplate == "" {
230
+ if widget . CommentsURLTemplate == "" {
230
231
commentsUrl = "https://www.reddit.com" + post .Permalink
231
232
} else {
232
- commentsUrl = templateRedditCommentsURL ( commentsUrlTemplate , subreddit , post .Id , post .Permalink )
233
+ commentsUrl = widget . parseCustomCommentsURL ( widget . Subreddit , post .Id , post .Permalink )
233
234
}
234
235
235
236
forumPost := forumPost {
@@ -249,19 +250,18 @@ func fetchSubredditPosts(
249
250
forumPost .TargetUrl = post .Url
250
251
}
251
252
252
- if showFlairs && post .Flair != "" {
253
+ if widget . ShowFlairs && post .Flair != "" {
253
254
forumPost .Tags = append (forumPost .Tags , post .Flair )
254
255
}
255
256
256
257
if len (post .ParentList ) > 0 {
257
258
forumPost .IsCrosspost = true
258
259
forumPost .TargetUrlDomain = "r/" + post .ParentList [0 ].Subreddit
259
260
260
- if commentsUrlTemplate == "" {
261
+ if widget . CommentsURLTemplate == "" {
261
262
forumPost .TargetUrl = "https://www.reddit.com" + post .ParentList [0 ].Permalink
262
263
} else {
263
- forumPost .TargetUrl = templateRedditCommentsURL (
264
- commentsUrlTemplate ,
264
+ forumPost .TargetUrl = widget .parseCustomCommentsURL (
265
265
post .ParentList [0 ].Subreddit ,
266
266
post .ParentList [0 ].Id ,
267
267
post .ParentList [0 ].Permalink ,
@@ -274,3 +274,32 @@ func fetchSubredditPosts(
274
274
275
275
return posts , nil
276
276
}
277
+
278
+ func (widget * redditWidget ) fetchNewAppAccessToken () (err error ) {
279
+ body := strings .NewReader ("grant_type=client_credentials" )
280
+ req , err := http .NewRequest ("POST" , "https://www.reddit.com/api/v1/access_token" , body )
281
+ if err != nil {
282
+ return fmt .Errorf ("creating request for app access token: %v" , err )
283
+ }
284
+
285
+ app := & widget .AppAuth
286
+ req .SetBasicAuth (app .ID , app .Secret )
287
+ req .Header .Add ("User-Agent" , app .Name + "/1.0" )
288
+ req .Header .Add ("Content-Type" , "application/x-www-form-urlencoded" )
289
+
290
+ type tokenResponse struct {
291
+ AccessToken string `json:"access_token"`
292
+ ExpiresIn int `json:"expires_in"`
293
+ }
294
+
295
+ client := ternary (widget .Proxy .client != nil , widget .Proxy .client , defaultHTTPClient )
296
+ response , err := decodeJsonFromRequest [tokenResponse ](client , req )
297
+ if err != nil {
298
+ return fmt .Errorf ("decoding Reddit API response: %v" , err )
299
+ }
300
+
301
+ app .accessToken = response .AccessToken
302
+ app .tokenExpiresAt = time .Now ().Add (time .Duration (response .ExpiresIn ) * time .Second )
303
+
304
+ return nil
305
+ }
0 commit comments