Skip to content

Commit 55af9d0

Browse files
committed
feat: enhance security and dynamic redirection
1 parent 13fe70d commit 55af9d0

File tree

13 files changed

+177
-43
lines changed

13 files changed

+177
-43
lines changed

.env.example

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ DB_DIALECT
2323

2424
DISCORD_CLIENT_ID
2525
DISCORD_CLIENT_SECRET
26-
DISCORD_REDIRECT_URI
26+
DISCORD_BASE_URL
2727
DISCORD_OAUTH2_REDIRECT_DOMAIN
2828
DISCORD_LATEST_VERSION
2929

config/discord.go

Lines changed: 2 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -2,15 +2,10 @@ package config
22

33
import "fmt"
44

5-
const (
6-
DiscordBaseURL = "https://discord.com/api"
7-
DiscordOAuthURL = DiscordBaseURL + "/oauth2/authorize"
8-
)
9-
105
type DiscordConfig struct {
116
ClientID string `env:"DISCORD_CLIENT_ID" envDefault:""`
127
ClientSecret string `env:"DISCORD_CLIENT_SECRET" envDefault:""`
13-
RedirectURI string `env:"DISCORD_REDIRECT_URI" envDefault:""`
8+
BaseURL string `env:"DISCORD_BASE_URL" envDefault:"https://discord.com/api"`
149
RedirectDomain string `env:"DISCORD_REDIRECT_DOMAIN" envDefault:".mocha-bot.xyz"`
1510
LatestVersion string `env:"DISCORD_LATEST_VERSION" envDefault:"v10"`
1611
}
@@ -20,9 +15,5 @@ func (d DiscordConfig) GetBaseURL(version string) string {
2015
version = d.LatestVersion
2116
}
2217

23-
return fmt.Sprintf("%s/%s", DiscordBaseURL, version)
24-
}
25-
26-
func (d DiscordConfig) GetRedirectURI() string {
27-
return d.RedirectURI
18+
return fmt.Sprintf("%s/%s", d.BaseURL, version)
2819
}

core/entity/discord.go

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,20 @@ const (
1414
)
1515

1616
type OauthCallbackRequest struct {
17-
Code string `query:"code"`
17+
Code string `query:"code"`
18+
RequestURL string `query:"request_url"`
19+
RedirectURL string `query:"redirect_url"`
20+
}
21+
22+
func (ocr *OauthCallbackRequest) Validate() error {
23+
err := validator.New().Struct(ocr)
24+
if err != nil {
25+
for _, e := range err.(validator.ValidationErrors) {
26+
return fmt.Errorf("field %s is invalid", e.Field())
27+
}
28+
}
29+
30+
return nil
1831
}
1932

2033
type AccessToken struct {

core/entity/user.go

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,16 @@ package entity
33
type User struct {
44
ID string
55
Username string
6-
Discriminator string
76
Avatar string
7+
Discriminator string
8+
PublicFlags int
9+
Flags int
10+
Banner string
11+
AccentColor string
12+
GlobalName string
13+
MFAEnabled bool
14+
Locale string
15+
PremiumType int
816
Email string
17+
Verified bool
918
}

core/module/discord.go

Lines changed: 6 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ import (
99
)
1010

1111
type DiscordUsecase interface {
12-
ExchangeCodeForToken(ctx context.Context, code string) (*entity.AccessToken, error)
12+
ExchangeCodeForToken(ctx context.Context, req *entity.OauthCallbackRequest) (*entity.AccessToken, error)
1313
ExchangeRefreshForToken(ctx context.Context, req *entity.RefreshTokenRequest) (*entity.AccessToken, error)
1414
RevokeToken(ctx context.Context, req *entity.RevokeTokenRequest) error
1515

@@ -26,30 +26,20 @@ func NewDiscordUsecase(repo repository.DiscordRepository) DiscordUsecase {
2626
}
2727
}
2828

29-
func (d *discordUsecase) ExchangeCodeForToken(ctx context.Context, code string) (*entity.AccessToken, error) {
30-
if code == "" {
31-
return nil, fmt.Errorf("%w, invalid code", entity.ErrorUnauthorized)
32-
}
33-
34-
accessToken, err := d.DiscordRepository.GetToken(ctx, code)
35-
if err != nil {
36-
return nil, err
29+
func (d *discordUsecase) ExchangeCodeForToken(ctx context.Context, req *entity.OauthCallbackRequest) (*entity.AccessToken, error) {
30+
if err := req.Validate(); err != nil {
31+
return nil, fmt.Errorf("%w, %w", entity.ErrorUnauthorized, err)
3732
}
3833

39-
return accessToken, nil
34+
return d.DiscordRepository.GetToken(ctx, req.Code, req.RequestURL)
4035
}
4136

4237
func (d *discordUsecase) ExchangeRefreshForToken(ctx context.Context, req *entity.RefreshTokenRequest) (*entity.AccessToken, error) {
4338
if err := req.Validate(); err != nil {
4439
return nil, fmt.Errorf("%w, %w", entity.ErrorUnauthorized, err)
4540
}
4641

47-
accessToken, err := d.DiscordRepository.GetTokenByRefresh(ctx, req.RefreshToken)
48-
if err != nil {
49-
return nil, err
50-
}
51-
52-
return accessToken, nil
42+
return d.DiscordRepository.GetTokenByRefresh(ctx, req.RefreshToken)
5343
}
5444

5545
func (d *discordUsecase) RevokeToken(ctx context.Context, req *entity.RevokeTokenRequest) error {

core/repository/discord/discord.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ import (
77
)
88

99
type DiscordRepository interface {
10-
GetToken(ctx context.Context, code string) (*entity.AccessToken, error)
10+
GetToken(ctx context.Context, code, redirectURL string) (*entity.AccessToken, error)
1111
GetTokenByRefresh(ctx context.Context, refreshToken string) (*entity.AccessToken, error)
1212
RevokeToken(ctx context.Context, req *entity.RevokeTokenRequest) error
1313

handler/http/discord.go

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -31,14 +31,20 @@ func NewDiscordHandler(cfg config.Config, discordUsecase module.DiscordUsecase)
3131
func (d *discordHandler) OauthCallback(c echo.Context) error {
3232
ctx := c.Request().Context()
3333

34+
// If the error occurs, the discord will still got redirected to the desired URL
35+
// with the error message in the query params and the URI fragment will also be appended
36+
// e.g. /redirect?error=error_message#access_token=token&token_type=type&expires_in=3600&refresh_token=refresh&scope=identify
37+
3438
req, err := parseOauthCallbackRequest(c)
3539
if err != nil {
36-
return c.JSON(parseOauthCallbackError(err))
40+
code, resp := parseOauthCallbackError(err)
41+
return c.Redirect(http.StatusTemporaryRedirect, parseOauthCallbackRedirectError(req.RedirectURL, code, resp))
3742
}
3843

39-
exchanged, err := d.discordUsecase.ExchangeCodeForToken(ctx, req.Code)
44+
exchanged, err := d.discordUsecase.ExchangeCodeForToken(ctx, req)
4045
if err != nil {
41-
return c.JSON(parseOauthCallbackError(err))
46+
code, resp := parseOauthCallbackError(err)
47+
return c.Redirect(http.StatusTemporaryRedirect, parseOauthCallbackRedirectError(req.RedirectURL, code, resp))
4248
}
4349

4450
isLocalhost := d.cfg.App.IsLocalhost()
@@ -52,7 +58,7 @@ func (d *discordHandler) OauthCallback(c echo.Context) error {
5258
c.SetCookie(cookie)
5359
}
5460

55-
return c.Redirect(http.StatusFound, d.cfg.App.Gateway)
61+
return c.Redirect(http.StatusFound, req.RedirectURL)
5662
}
5763

5864
func (d *discordHandler) RefreshToken(c echo.Context) error {

handler/http/parser.discord.go

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,12 @@ import (
44
"errors"
55
"fmt"
66
"net/http"
7+
"net/url"
78

89
"github.com/labstack/echo/v4"
910
"github.com/mocha-bot/mochus/core/entity"
1011
cookiey "github.com/mocha-bot/mochus/pkg/cookie"
12+
"github.com/mocha-bot/mochus/pkg/echoy"
1113
zLog "github.com/rs/zerolog/log"
1214
)
1315

@@ -18,19 +20,56 @@ func parseOauthCallbackRequest(c echo.Context) (*entity.OauthCallbackRequest, er
1820
return nil, fmt.Errorf("%w: %s", entity.ErrorBind, err)
1921
}
2022

23+
parsedURL, err := url.ParseRequestURI(c.Request().RequestURI)
24+
if err != nil {
25+
return nil, fmt.Errorf("%w: %s", entity.ErrorBind, err)
26+
}
27+
28+
redirectURL := parsedURL.Query().Get(RedirectURLKey)
29+
if redirectURL == "" {
30+
return nil, fmt.Errorf("%w: %s", entity.ErrorBind, "redirect_url is required")
31+
}
32+
33+
finalURL := url.URL{
34+
Scheme: echoy.GetScheme(c),
35+
Host: c.Request().Host,
36+
Path: c.Request().URL.Path,
37+
RawQuery: url.Values{RedirectURLKey: {redirectURL}}.Encode(),
38+
}
39+
40+
req.RequestURL, err = url.QueryUnescape(finalURL.String())
41+
if err != nil {
42+
return nil, fmt.Errorf("%w: %s", entity.ErrorBind, err)
43+
}
44+
2145
return req, nil
2246
}
2347

2448
func parseOauthCallbackError(err error) (code int, i any) {
2549
switch {
2650
case errors.Is(err, entity.ErrorBind):
2751
return http.StatusBadRequest, Response{Message: err.Error()}
52+
case errors.Is(err, entity.ErrorUnauthorized):
53+
return http.StatusUnauthorized, Response{Message: err.Error()}
54+
case errors.Is(err, entity.ErrorBadRequest):
55+
return http.StatusBadRequest, Response{Message: err.Error()}
2856
default:
2957
zLog.Error().Err(err).Msg("Internal server error")
3058
return http.StatusInternalServerError, Response{Message: "Internal server error"}
3159
}
3260
}
3361

62+
func parseOauthCallbackRedirectError(redirectURL string, code int, i any) (newRedirectURL string) {
63+
redirectURLParsed, _ := url.Parse(redirectURL)
64+
65+
if i == nil {
66+
return redirectURLParsed.String()
67+
}
68+
69+
redirectURLParsed.RawQuery = url.Values{"error": {i.(Response).Message}}.Encode()
70+
return redirectURLParsed.String()
71+
}
72+
3473
func parseRefreshTokenRequest(c echo.Context) (*entity.RefreshTokenRequest, error) {
3574
req := new(entity.RefreshTokenRequest)
3675

@@ -48,6 +87,8 @@ func parseRefreshTokenError(err error) (code int, i any) {
4887
switch {
4988
case errors.Is(err, entity.ErrorUnauthorized):
5089
return http.StatusUnauthorized, Response{Message: err.Error()}
90+
case errors.Is(err, entity.ErrorBadRequest):
91+
return http.StatusBadRequest, Response{Message: err.Error()}
5192
default:
5293
zLog.Error().Err(err).Msg("Internal server error")
5394
return http.StatusInternalServerError, Response{Message: "Internal server error"}
@@ -76,6 +117,8 @@ func parseRevokeTokenError(err error) (code int, i any) {
76117
switch {
77118
case errors.Is(err, entity.ErrorUnauthorized):
78119
return http.StatusUnauthorized, Response{Message: err.Error()}
120+
case errors.Is(err, entity.ErrorBadRequest):
121+
return http.StatusBadRequest, Response{Message: err.Error()}
79122
default:
80123
zLog.Error().Err(err).Msg("Internal server error")
81124
return http.StatusInternalServerError, Response{Message: "Internal server error"}
@@ -106,6 +149,8 @@ func parseGetUserByTokenError(err error) (code int, i any) {
106149
switch {
107150
case errors.Is(err, entity.ErrorUnauthorized):
108151
return http.StatusUnauthorized, Response{Message: err.Error()}
152+
case errors.Is(err, entity.ErrorBadRequest):
153+
return http.StatusBadRequest, Response{Message: err.Error()}
109154
default:
110155
zLog.Error().Err(err).Msg("Internal server error")
111156
return http.StatusInternalServerError, Response{Message: "Internal server error"}

handler/http/request.go

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
package http_handler
2+
3+
const (
4+
RedirectURLKey = "redirect_url"
5+
)

pkg/echoy/scheme.go

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
package echoy
2+
3+
import "github.com/labstack/echo/v4"
4+
5+
const (
6+
HTTP = "http"
7+
HTTPS = "https"
8+
)
9+
10+
func GetScheme(c echo.Context) string {
11+
if c.Request().TLS != nil {
12+
return HTTPS
13+
}
14+
return HTTP
15+
}

repository/discord/api.go

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
package discord_repository
2+
3+
const (
4+
Oauth2Authorize = "/oauth2/authorize"
5+
Oauth2GetToken = "/oauth2/token"
6+
Oauth2RevokeToken = "/oauth2/token/revoke"
7+
8+
GetUser = "/users/@me"
9+
)

0 commit comments

Comments
 (0)