Skip to content
Open
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 23 additions & 23 deletions internal/animeupdate/service.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,11 @@ import (
)

type Service interface {
Store(ctx context.Context, animeupdate *domain.AnimeUpdate) error
Store(ctx context.Context, userID int, animeupdate *domain.AnimeUpdate) error
GetByID(ctx context.Context, req *domain.GetAnimeUpdateRequest) (*domain.AnimeUpdate, error)
UpdateAnimeList(ctx context.Context, anime *domain.AnimeUpdate, event domain.PlexEvent) error
UpdateAnimeList(ctx context.Context, userID int, anime *domain.AnimeUpdate, event domain.PlexEvent) error
Count(ctx context.Context) (int, error)
GetRecentUnique(ctx context.Context, limit int) ([]*domain.AnimeUpdate, error)
GetRecentUnique(ctx context.Context, userID int, limit int) ([]*domain.AnimeUpdate, error)
GetByPlexID(ctx context.Context, plexID int64) (*domain.AnimeUpdate, error)
GetByPlexIDs(ctx context.Context, plexIDs []int64) ([]*domain.AnimeUpdate, error)
}
Expand All @@ -39,36 +39,36 @@ func NewService(log zerolog.Logger, repo domain.AnimeUpdateRepo, animeSvc anime.
}
}

func (s *service) Store(ctx context.Context, animeupdate *domain.AnimeUpdate) error {
return s.repo.Store(ctx, animeupdate)
func (s *service) Store(ctx context.Context, userID int, animeupdate *domain.AnimeUpdate) error {
return s.repo.Store(ctx, userID, animeupdate)
}

func (s *service) GetByID(ctx context.Context, req *domain.GetAnimeUpdateRequest) (*domain.AnimeUpdate, error) {
return s.repo.GetByID(ctx, req)
}

func (s *service) UpdateAnimeList(ctx context.Context, anime *domain.AnimeUpdate, event domain.PlexEvent) error {
func (s *service) UpdateAnimeList(ctx context.Context, userID int, anime *domain.AnimeUpdate, event domain.PlexEvent) error {
switch event {
case domain.PlexRateEvent:
return s.handleRateEvent(ctx, anime)
return s.handleRateEvent(ctx, userID, anime)
case domain.PlexScrobbleEvent:
return s.handleScrobbleEvent(ctx, anime)
return s.handleScrobbleEvent(ctx, userID, anime)
}
return nil
}

func (s *service) handleRateEvent(ctx context.Context, anime *domain.AnimeUpdate) error {
return s.handleEvent(ctx, anime, anime.UpdateRating, false)
func (s *service) handleRateEvent(ctx context.Context, userID int, anime *domain.AnimeUpdate) error {
return s.handleEvent(ctx, userID, anime, anime.UpdateRating, false)
}

func (s *service) handleScrobbleEvent(ctx context.Context, anime *domain.AnimeUpdate) error {
return s.handleEvent(ctx, anime, anime.UpdateWatchStatus, true)
func (s *service) handleScrobbleEvent(ctx context.Context, userID int, anime *domain.AnimeUpdate) error {
return s.handleEvent(ctx, userID, anime, anime.UpdateWatchStatus, true)
}

func (s *service) handleEvent(ctx context.Context, anime *domain.AnimeUpdate, updateFunc func(context.Context, *mal.Client) error, isScrobble bool) error {
func (s *service) handleEvent(ctx context.Context, userID int, anime *domain.AnimeUpdate, updateFunc func(context.Context, *mal.Client) error, isScrobble bool) error {
if anime.SourceDB == domain.MAL {
anime.MALId = anime.SourceId
return s.updateAndStore(ctx, anime, updateFunc)
return s.updateAndStore(ctx, userID, anime, updateFunc)
}

convertedAnime := s.convertAniDBToTVDB(ctx, anime)
Expand All @@ -78,18 +78,18 @@ func (s *service) handleEvent(ctx context.Context, anime *domain.AnimeUpdate, up
if isScrobble {
anime.EpisodeNum = animeMap.CalculateEpNum(anime.EpisodeNum)
}
return s.updateAndStore(ctx, anime, updateFunc)
return s.updateAndStore(ctx, userID, anime, updateFunc)
}

if anime.SeasonNum == 1 {
return s.updateFromDBAndStore(ctx, anime, updateFunc)
return s.updateFromDBAndStore(ctx, userID, anime, updateFunc)
}

return err
}

func (s *service) updateAndStore(ctx context.Context, anime *domain.AnimeUpdate, updateFunc func(context.Context, *mal.Client) error) error {
client, err := s.malauthService.GetMalClient(ctx)
func (s *service) updateAndStore(ctx context.Context, userID int, anime *domain.AnimeUpdate, updateFunc func(context.Context, *mal.Client) error) error {
client, err := s.malauthService.GetMalClient(ctx, userID)
if err != nil {
return err
}
Expand All @@ -98,10 +98,10 @@ func (s *service) updateAndStore(ctx context.Context, anime *domain.AnimeUpdate,
return err
}
s.log.Info().Interface("status", anime.ListStatus).Msg("MyAnimeList Updated Successfully")
return s.Store(ctx, anime)
return s.Store(ctx, userID, anime)
}

func (s *service) updateFromDBAndStore(ctx context.Context, anime *domain.AnimeUpdate, updateFunc func(context.Context, *mal.Client) error) error {
func (s *service) updateFromDBAndStore(ctx context.Context, userID int, anime *domain.AnimeUpdate, updateFunc func(context.Context, *mal.Client) error) error {
req := &domain.GetAnimeRequest{
IDtype: anime.SourceDB,
Id: anime.SourceId,
Expand All @@ -113,7 +113,7 @@ func (s *service) updateFromDBAndStore(ctx context.Context, anime *domain.AnimeU
}

anime.MALId = animeFromDB.MALId
return s.updateAndStore(ctx, anime, updateFunc)
return s.updateAndStore(ctx, userID, anime, updateFunc)
}

func (s *service) convertAniDBToTVDB(ctx context.Context, anime *domain.AnimeUpdate) *domain.AnimeUpdate {
Expand Down Expand Up @@ -147,8 +147,8 @@ func (s *service) Count(ctx context.Context) (int, error) {
return s.repo.Count(ctx)
}

func (s *service) GetRecentUnique(ctx context.Context, limit int) ([]*domain.AnimeUpdate, error) {
return s.repo.GetRecentUnique(ctx, limit)
func (s *service) GetRecentUnique(ctx context.Context, userID int, limit int) ([]*domain.AnimeUpdate, error) {
return s.repo.GetRecentUnique(ctx, userID, limit)
}

func (s *service) GetByPlexID(ctx context.Context, plexID int64) (*domain.AnimeUpdate, error) {
Expand Down
34 changes: 24 additions & 10 deletions internal/api/service.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,11 @@ import (
)

type Service interface {
List(ctx context.Context) ([]domain.APIKey, error)
Store(ctx context.Context, key *domain.APIKey) error
Delete(ctx context.Context, key string) error
List(ctx context.Context, userID int) ([]domain.APIKey, error)
Store(ctx context.Context, userID int, key *domain.APIKey) error
Delete(ctx context.Context, userID int, key string) error
ValidateAPIKey(ctx context.Context, token string) bool
GetUserIDByAPIKey(ctx context.Context, token string) (int, error)
}

type service struct {
Expand All @@ -33,24 +34,27 @@ func NewService(log zerolog.Logger, repo domain.APIRepo) Service {
}
}

func (s *service) List(ctx context.Context) ([]domain.APIKey, error) {
func (s *service) List(ctx context.Context, userID int) ([]domain.APIKey, error) {
if len(s.keyCache) > 0 {
keys := make([]domain.APIKey, 0, len(s.keyCache))

for _, key := range s.keyCache {
keys = append(keys, key)
// Filter by userID when returning from cache
if key.UserID == userID {
keys = append(keys, key)
}
}

return keys, nil
}

return s.repo.GetAllAPIKeys(ctx)
return s.repo.GetAllAPIKeys(ctx, userID)
}

func (s *service) Store(ctx context.Context, apiKey *domain.APIKey) error {
func (s *service) Store(ctx context.Context, userID int, apiKey *domain.APIKey) error {
apiKey.Key = GenerateSecureToken(16)

if err := s.repo.Store(ctx, apiKey); err != nil {
if err := s.repo.Store(ctx, userID, apiKey); err != nil {
return err
}

Expand All @@ -62,13 +66,13 @@ func (s *service) Store(ctx context.Context, apiKey *domain.APIKey) error {
return nil
}

func (s *service) Delete(ctx context.Context, key string) error {
func (s *service) Delete(ctx context.Context, userID int, key string) error {
_, err := s.repo.GetKey(ctx, key)
if err != nil {
return err
}

err = s.repo.Delete(ctx, key)
err = s.repo.Delete(ctx, userID, key)
if err != nil {
return errors.Wrap(err, fmt.Sprintf("could not delete api key: %s", key))
}
Expand Down Expand Up @@ -98,6 +102,16 @@ func (s *service) ValidateAPIKey(ctx context.Context, key string) bool {
return true
}

func (s *service) GetUserIDByAPIKey(ctx context.Context, token string) (int, error) {
// First check cache
if apiKey, ok := s.keyCache[token]; ok {
return apiKey.UserID, nil
}

// If not in cache, get from database
return s.repo.GetUserIDByAPIKey(ctx, token)
}

func GenerateSecureToken(length int) string {
b := make([]byte, length)
if _, err := rand.Read(b); err != nil {
Expand Down
27 changes: 27 additions & 0 deletions internal/auth/service.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,12 @@ import (
type Service interface {
GetUserCount(ctx context.Context) (int, error)
Login(ctx context.Context, username, password string) (*domain.User, error)
FindByUsername(ctx context.Context, username string) (*domain.User, error)
FindAll(ctx context.Context) ([]*domain.User, error)
CreateUser(ctx context.Context, req domain.CreateUserRequest) error
CreateUserAdmin(ctx context.Context, req domain.CreateUserRequest) error
UpdateUser(ctx context.Context, req domain.UpdateUserRequest) error
Delete(ctx context.Context, username string) error
ResetPassword(ctx context.Context, username, newPassword string) error
CreateHash(password string) (hash string, err error)
ComparePasswordAndHash(password string, hash string) (match bool, err error)
Expand Down Expand Up @@ -66,6 +70,10 @@ func (s *service) Login(ctx context.Context, username, password string) (*domain
return u, nil
}

func (s *service) FindByUsername(ctx context.Context, username string) (*domain.User, error) {
return s.userSvc.FindByUsername(ctx, username)
}

func (s *service) CreateUser(ctx context.Context, req domain.CreateUserRequest) error {
if req.Username == "" {
return errors.New("validation error: empty username supplied")
Expand Down Expand Up @@ -198,3 +206,22 @@ func (s *service) CreateHash(password string) (hash string, err error) {

return argon2id.CreateHash(password, argon2id.DefaultParams)
}

func (s *service) FindAll(ctx context.Context) ([]*domain.User, error) {
return s.userSvc.FindAll(ctx)
}

func (s *service) CreateUserAdmin(ctx context.Context, req domain.CreateUserRequest) error {
// Hash password before storing
hash, err := s.CreateHash(req.Password)
if err != nil {
return errors.Wrap(err, "could not create password hash")
}
req.Password = hash

return s.userSvc.CreateUserAdmin(ctx, req)
}

func (s *service) Delete(ctx context.Context, username string) error {
return s.userSvc.Delete(ctx, username)
}
25 changes: 15 additions & 10 deletions internal/database/animeupdate.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,10 @@ func NewAnimeUpdateRepo(log zerolog.Logger, db *DB) domain.AnimeUpdateRepo {
}
}

func (repo *AnimeUpdateRepo) Store(ctx context.Context, r *domain.AnimeUpdate) error {
func (repo *AnimeUpdateRepo) Store(ctx context.Context, userID int, r *domain.AnimeUpdate) error {
// Set the userID in the AnimeUpdate struct
r.UserID = userID

listDetails, err := json.Marshal(r.ListDetails)
if err != nil {
return errors.Wrap(err, "failed to marshal listDetails")
Expand All @@ -36,8 +39,8 @@ func (repo *AnimeUpdateRepo) Store(ctx context.Context, r *domain.AnimeUpdate) e

queryBuilder := repo.db.squirrel.
Insert("anime_update").
Columns("mal_id", "source_db", "source_id", "episode_num", "season_num", "time_stamp", "list_details", "list_status", "plex_id").
Values(r.MALId, r.SourceDB, r.SourceId, r.EpisodeNum, r.SeasonNum, r.Timestamp, string(listDetails), string(listStatus), r.PlexId).
Columns("user_id", "mal_id", "source_db", "source_id", "episode_num", "season_num", "time_stamp", "list_details", "list_status", "plex_id").
Values(r.UserID, r.MALId, r.SourceDB, r.SourceId, r.EpisodeNum, r.SeasonNum, r.Timestamp, string(listDetails), string(listStatus), r.PlexId).
Suffix("RETURNING id").RunWith(repo.db.handler)

var retID int64
Expand Down Expand Up @@ -79,16 +82,18 @@ func (repo *AnimeUpdateRepo) Count(ctx context.Context) (int, error) {
return count, nil
}

func (repo *AnimeUpdateRepo) GetRecentUnique(ctx context.Context, limit int) ([]*domain.AnimeUpdate, error) {
func (repo *AnimeUpdateRepo) GetRecentUnique(ctx context.Context, userID int, limit int) ([]*domain.AnimeUpdate, error) {
latest := repo.db.squirrel.
Select("mal_id, MAX(time_stamp) AS max_ts").
From("anime_update").
Where("user_id = ?", userID).
GroupBy("mal_id")

queryBuilder := repo.db.squirrel.
Select("au.id, au.mal_id, au.source_db, au.source_id, au.episode_num, au.season_num, au.time_stamp, au.list_details, au.list_status, au.plex_id").
Select("au.id, au.user_id, au.mal_id, au.source_db, au.source_id, au.episode_num, au.season_num, au.time_stamp, au.list_details, au.list_status, au.plex_id").
FromSelect(latest, "latest").
Join("anime_update au ON latest.mal_id = au.mal_id AND latest.max_ts = au.time_stamp").
Where("au.user_id = ?", userID).
OrderBy("au.time_stamp DESC").
Limit(uint64(limit))

Expand All @@ -107,7 +112,7 @@ func (repo *AnimeUpdateRepo) GetRecentUnique(ctx context.Context, limit int) ([]
for rows.Next() {
var au domain.AnimeUpdate
var listDetailsBytes, listStatusBytes []byte
if err := rows.Scan(&au.ID, &au.MALId, &au.SourceDB, &au.SourceId, &au.EpisodeNum, &au.SeasonNum, &au.Timestamp, &listDetailsBytes, &listStatusBytes, &au.PlexId); err != nil {
if err := rows.Scan(&au.ID, &au.UserID, &au.MALId, &au.SourceDB, &au.SourceId, &au.EpisodeNum, &au.SeasonNum, &au.Timestamp, &listDetailsBytes, &listStatusBytes, &au.PlexId); err != nil {
return nil, errors.Wrap(err, "error scanning row")
}
if err := json.Unmarshal(listDetailsBytes, &au.ListDetails); err != nil {
Expand All @@ -123,7 +128,7 @@ func (repo *AnimeUpdateRepo) GetRecentUnique(ctx context.Context, limit int) ([]

func (repo *AnimeUpdateRepo) GetByPlexID(ctx context.Context, plexID int64) (*domain.AnimeUpdate, error) {
queryBuilder := repo.db.squirrel.
Select("id, mal_id, source_db, source_id, episode_num, season_num, time_stamp, list_details, list_status, plex_id").
Select("id, user_id, mal_id, source_db, source_id, episode_num, season_num, time_stamp, list_details, list_status, plex_id").
From("anime_update").
Where("plex_id = ?", plexID).
OrderBy("time_stamp DESC").
Expand All @@ -137,7 +142,7 @@ func (repo *AnimeUpdateRepo) GetByPlexID(ctx context.Context, plexID int64) (*do
row := repo.db.handler.QueryRowContext(ctx, query, args...)
var au domain.AnimeUpdate
var listDetailsBytes, listStatusBytes []byte
if err := row.Scan(&au.ID, &au.MALId, &au.SourceDB, &au.SourceId, &au.EpisodeNum, &au.SeasonNum, &au.Timestamp, &listDetailsBytes, &listStatusBytes, &au.PlexId); err != nil {
if err := row.Scan(&au.ID, &au.UserID, &au.MALId, &au.SourceDB, &au.SourceId, &au.EpisodeNum, &au.SeasonNum, &au.Timestamp, &listDetailsBytes, &listStatusBytes, &au.PlexId); err != nil {
if err == sql.ErrNoRows {
return nil, nil // No update for this plex_id
}
Expand All @@ -158,7 +163,7 @@ func (repo *AnimeUpdateRepo) GetByPlexIDs(ctx context.Context, plexIDs []int64)
}

queryBuilder := repo.db.squirrel.
Select("id, mal_id, source_db, source_id, episode_num, season_num, time_stamp, list_details, list_status, plex_id").
Select("id, user_id, mal_id, source_db, source_id, episode_num, season_num, time_stamp, list_details, list_status, plex_id").
From("anime_update").
Where(sq.Eq{"plex_id": plexIDs}).
OrderBy("time_stamp DESC")
Expand All @@ -180,7 +185,7 @@ func (repo *AnimeUpdateRepo) GetByPlexIDs(ctx context.Context, plexIDs []int64)
for rows.Next() {
var au domain.AnimeUpdate
var listDetailsBytes, listStatusBytes []byte
if err := rows.Scan(&au.ID, &au.MALId, &au.SourceDB, &au.SourceId, &au.EpisodeNum, &au.SeasonNum, &au.Timestamp, &listDetailsBytes, &listStatusBytes, &au.PlexId); err != nil {
if err := rows.Scan(&au.ID, &au.UserID, &au.MALId, &au.SourceDB, &au.SourceId, &au.EpisodeNum, &au.SeasonNum, &au.Timestamp, &listDetailsBytes, &listStatusBytes, &au.PlexId); err != nil {
return nil, errors.Wrap(err, "error scanning row")
}
if err := json.Unmarshal(listDetailsBytes, &au.ListDetails); err != nil {
Expand Down
Loading