@@ -2,10 +2,13 @@ package glance
2
2
3
3
import (
4
4
"context"
5
+ "encoding/base64"
6
+ "encoding/json"
5
7
"errors"
6
8
"fmt"
7
9
"html"
8
10
"html/template"
11
+ "io"
9
12
"net/http"
10
13
"net/url"
11
14
"strings"
19
22
20
23
type redditWidget struct {
21
24
widgetBase `yaml:",inline"`
25
+ redditAccessToken string
26
+ redditAppName string `yaml:"reddit-app-name"`
27
+ redditClientID string `yaml:"reddit-client-id"`
28
+ redditClientSecret string `yaml:"reddit-client-secret"`
22
29
Posts forumPostList `yaml:"-"`
23
30
Subreddit string `yaml:"subreddit"`
24
31
Proxy proxyOptionsField `yaml:"proxy"`
@@ -35,6 +42,67 @@ type redditWidget struct {
35
42
RequestUrlTemplate string `yaml:"request-url-template"`
36
43
}
37
44
45
+ type redditTokenResponse struct {
46
+ AccessToken string `json:"access_token"`
47
+ TokenType string `json:"token_type"`
48
+ ExpiresIn int `json:"expires_in"`
49
+ Scope string `json:"scope"`
50
+ }
51
+
52
+ func (widget * redditWidget ) fetchRedditAccessToken () error {
53
+ // Only execute if a matching configuration is provider
54
+ if widget .redditAppName == "" || widget .redditClientID == "" || widget .redditClientSecret == "" {
55
+ return nil
56
+ }
57
+
58
+ auth := base64 .StdEncoding .EncodeToString ([]byte (widget .redditClientID + ":" + widget .redditClientSecret ))
59
+
60
+ // Prepare form data
61
+ data := url.Values {}
62
+ data .Set ("grant_type" , "client_credentials" )
63
+
64
+ // Create request
65
+ req , err := http .NewRequest ("POST" , "https://www.reddit.com/api/v1/access_token" , strings .NewReader (data .Encode ()))
66
+ if err != nil {
67
+ return err
68
+ }
69
+
70
+ // Set headers
71
+ req .Header .Add ("Authorization" , "Basic " + auth )
72
+ req .Header .Add ("User-Agent" , fmt .Sprintf ("%s/1.0" , widget .redditAppName ))
73
+ req .Header .Add ("Content-Type" , "application/x-www-form-urlencoded" )
74
+
75
+ // Make the request
76
+ client := & http.Client {}
77
+ resp , err := client .Do (req )
78
+ if err != nil {
79
+ return fmt .Errorf ("querying Reddit API: %w" , err )
80
+ }
81
+ defer resp .Body .Close ()
82
+
83
+ // Read response body
84
+ body , err := io .ReadAll (resp .Body )
85
+ if err != nil {
86
+ return fmt .Errorf ("reading response body: %w" , err )
87
+ }
88
+
89
+ // Check for error status code
90
+ if resp .StatusCode != http .StatusOK {
91
+ return fmt .Errorf ("API request failed with status %d: %s" , resp .StatusCode , string (body ))
92
+ }
93
+
94
+ // Parse JSON response
95
+ var tokenResp redditTokenResponse
96
+ err = json .Unmarshal (body , & tokenResp )
97
+ if err != nil {
98
+ return fmt .Errorf ("unmarshalling Reddit API response: %w" , err )
99
+ }
100
+
101
+ widget .redditAccessToken = tokenResp .AccessToken
102
+
103
+ return nil
104
+ }
105
+
38
106
func (widget * redditWidget ) initialize () error {
39
107
if widget .Subreddit == "" {
40
108
return errors .New ("subreddit is required" )
@@ -62,6 +130,10 @@ func (widget *redditWidget) initialize() error {
62
130
}
63
131
}
64
132
133
+ if err := widget .fetchRedditAccessToken (); err != nil {
134
+ return fmt .Errorf ("fetching Reddit API access token: %w" , err )
135
+ }
136
+
65
137
widget .
66
138
withTitle ("r/" + widget .Subreddit ).
67
139
withTitleURL ("https://www.reddit.com/r/" + widget .Subreddit + "/" ).
@@ -97,6 +169,8 @@ func (widget *redditWidget) update(ctx context.Context) {
97
169
widget .RequestUrlTemplate ,
98
170
widget .Proxy .client ,
99
171
widget .ShowFlairs ,
172
+ widget .redditAppName ,
173
+ widget .redditAccessToken ,
100
174
)
101
175
102
176
if ! widget .canContinueUpdateAfterHandlingErr (err ) {
@@ -172,6 +246,8 @@ func fetchSubredditPosts(
172
246
requestUrlTemplate string ,
173
247
proxyClient * http.Client ,
174
248
showFlairs bool ,
249
+ redditAppName string ,
250
+ redditAccessToken string ,
175
251
) (forumPostList , error ) {
176
252
query := url.Values {}
177
253
var requestUrl string
@@ -185,10 +261,18 @@ func fetchSubredditPosts(
185
261
query .Set ("t" , topPeriod )
186
262
}
187
263
264
+ var baseURL string
265
+
266
+ if redditAccessToken != "" {
267
+ baseURL = "https://oauth.reddit.com"
268
+ } else {
269
+ baseURL = "https://www.reddit.com"
270
+ }
271
+
188
272
if search != "" {
189
- requestUrl = fmt .Sprintf ("https://www.reddit.com/ search.json?%s" , query .Encode ())
273
+ requestUrl = fmt .Sprintf ("%s/ search.json?%s" , baseURL , query .Encode ())
190
274
} else {
191
- requestUrl = fmt .Sprintf ("https://www.reddit.com/ r/%s/%s.json?%s" , subreddit , sort , query .Encode ())
275
+ requestUrl = fmt .Sprintf ("%s/ r/%s/%s.json?%s" , baseURL , subreddit , sort , query .Encode ())
192
276
}
193
277
194
278
var client requestDoer = defaultHTTPClient
@@ -205,7 +289,16 @@ func fetchSubredditPosts(
205
289
}
206
290
207
291
// Required to increase rate limit, otherwise Reddit randomly returns 429 even after just 2 requests
208
- setBrowserUserAgentHeader (request )
292
+ if redditAppName == "" {
293
+ setBrowserUserAgentHeader (request )
294
+ } else {
295
+ request .Header .Set ("User-Agent" , fmt .Sprintf ("%s/1.0" , redditAppName ))
296
+ }
297
+
298
+ if redditAccessToken != "" {
299
+ request .Header .Set ("Authorization" , fmt .Sprintf ("Bearer %s" , redditAccessToken ))
300
+ }
301
+
209
302
responseJson , err := decodeJsonFromRequest [subredditResponseJson ](client , request )
210
303
if err != nil {
211
304
return nil , err
0 commit comments