Skip to content

Commit b42e117

Browse files
committed
refactor(github): split enterprise-specific code
1 parent 7008ad2 commit b42e117

File tree

3 files changed

+160
-140
lines changed

3 files changed

+160
-140
lines changed

pkg/sources/github/enterprise.go

Lines changed: 157 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,157 @@
1+
// Code used exclusively by the enterprise version.
2+
// https://github.com/trufflesecurity/trufflehog/pull/3298#issuecomment-2510010947
3+
4+
package github
5+
6+
import (
7+
"errors"
8+
"fmt"
9+
"net/http"
10+
"net/url"
11+
"strings"
12+
"sync/atomic"
13+
14+
"github.com/google/go-github/v67/github"
15+
16+
"github.com/trufflesecurity/trufflehog/v3/pkg/common"
17+
"github.com/trufflesecurity/trufflehog/v3/pkg/context"
18+
"github.com/trufflesecurity/trufflehog/v3/pkg/handlers"
19+
"github.com/trufflesecurity/trufflehog/v3/pkg/pb/source_metadatapb"
20+
"github.com/trufflesecurity/trufflehog/v3/pkg/sources"
21+
)
22+
23+
// Chunks emits chunks of bytes over a channel.
24+
func (s *Source) Chunks(ctx context.Context, chunksChan chan *sources.Chunk, targets ...sources.ChunkingTarget) error {
25+
chunksReporter := sources.ChanReporter{Ch: chunksChan}
26+
// If targets are provided, we're only scanning the data in those targets.
27+
// Otherwise, we're scanning all data.
28+
// This allows us to only scan the commit where a vulnerability was found.
29+
if len(targets) > 0 {
30+
errs := s.scanTargets(ctx, targets, chunksReporter)
31+
return errors.Join(errs...)
32+
}
33+
34+
// Reset consumption and rate limit metrics on each run.
35+
githubNumRateLimitEncountered.WithLabelValues(s.name).Set(0)
36+
githubSecondsSpentRateLimited.WithLabelValues(s.name).Set(0)
37+
githubReposScanned.WithLabelValues(s.name).Set(0)
38+
39+
// We don't care about handling enumerated values as they happen during
40+
// the normal Chunks flow because we enumerate and scan in two steps.
41+
noopReporter := sources.VisitorReporter{
42+
VisitUnit: func(context.Context, sources.SourceUnit) error {
43+
return nil
44+
},
45+
}
46+
err := s.Enumerate(ctx, noopReporter)
47+
if err != nil {
48+
return fmt.Errorf("error enumerating: %w", err)
49+
}
50+
51+
return s.scan(ctx, chunksReporter)
52+
}
53+
54+
func (s *Source) scan(ctx context.Context, reporter sources.ChunkReporter) error {
55+
var scannedCount uint64 = 1
56+
57+
ctx.Logger().V(2).Info("Found repos to scan", "count", len(s.repos))
58+
59+
// If there is resume information available, limit this scan to only the repos that still need scanning.
60+
reposToScan, progressIndexOffset := sources.FilterReposToResume(s.repos, s.GetProgress().EncodedResumeInfo)
61+
s.repos = reposToScan
62+
63+
for i, repoURL := range s.repos {
64+
s.jobPool.Go(func() error {
65+
if common.IsDone(ctx) {
66+
return nil
67+
}
68+
ctx := context.WithValue(ctx, "repo", repoURL)
69+
70+
// TODO: set progress complete is being called concurrently with i
71+
s.setProgressCompleteWithRepo(i, progressIndexOffset, repoURL)
72+
// Ensure the repo is removed from the resume info after being scanned.
73+
defer func(s *Source, repoURL string) {
74+
s.resumeInfoMutex.Lock()
75+
defer s.resumeInfoMutex.Unlock()
76+
s.resumeInfoSlice = sources.RemoveRepoFromResumeInfo(s.resumeInfoSlice, repoURL)
77+
}(s, repoURL)
78+
79+
if err := s.scanRepo(ctx, repoURL, reporter); err != nil {
80+
ctx.Logger().Error(err, "error scanning repo")
81+
return nil
82+
}
83+
84+
atomic.AddUint64(&scannedCount, 1)
85+
return nil
86+
})
87+
}
88+
89+
_ = s.jobPool.Wait()
90+
s.SetProgressComplete(len(s.repos), len(s.repos), "Completed GitHub scan", "")
91+
92+
return nil
93+
}
94+
95+
func (s *Source) scanTargets(ctx context.Context, targets []sources.ChunkingTarget, reporter sources.ChunkReporter) []error {
96+
var errs []error
97+
for _, tgt := range targets {
98+
if err := s.scanTarget(ctx, tgt, reporter); err != nil {
99+
ctx.Logger().Error(err, "error scanning target")
100+
errs = append(errs, &sources.TargetedScanError{Err: err, SecretID: tgt.SecretID})
101+
}
102+
}
103+
104+
return errs
105+
}
106+
107+
func (s *Source) scanTarget(ctx context.Context, target sources.ChunkingTarget, reporter sources.ChunkReporter) error {
108+
metaType, ok := target.QueryCriteria.GetData().(*source_metadatapb.MetaData_Github)
109+
if !ok {
110+
return fmt.Errorf("unable to cast metadata type for targeted scan")
111+
}
112+
meta := metaType.Github
113+
114+
u, err := url.Parse(meta.GetLink())
115+
if err != nil {
116+
return fmt.Errorf("unable to parse GitHub URL: %w", err)
117+
}
118+
119+
// The owner is the second segment and the repo is the third segment of the path.
120+
// Ex: https://github.com/owner/repo/.....
121+
segments := strings.Split(u.Path, "/")
122+
if len(segments) < 3 {
123+
return fmt.Errorf("invalid GitHub URL")
124+
}
125+
126+
readCloser, resp, err := s.connector.APIClient().Repositories.DownloadContents(
127+
ctx,
128+
segments[1],
129+
segments[2],
130+
meta.GetFile(),
131+
&github.RepositoryContentGetOptions{Ref: meta.GetCommit()})
132+
// As of this writing, if the returned readCloser is not nil, it's just the Body of the returned github.Response, so
133+
// there's no need to independently close it.
134+
if resp != nil && resp.Body != nil {
135+
defer resp.Body.Close()
136+
}
137+
if err != nil {
138+
return fmt.Errorf("could not download file for scan: %w", err)
139+
}
140+
if resp.StatusCode != http.StatusOK {
141+
return fmt.Errorf("unexpected HTTP response status when trying to download file for scan: %v", resp.Status)
142+
}
143+
144+
chunkSkel := sources.Chunk{
145+
SourceType: s.Type(),
146+
SourceName: s.name,
147+
SourceID: s.SourceID(),
148+
JobID: s.JobID(),
149+
SecretID: target.SecretID,
150+
SourceMetadata: &source_metadatapb.MetaData{
151+
Data: &source_metadatapb.MetaData_Github{Github: meta},
152+
},
153+
Verify: s.verify,
154+
}
155+
fileCtx := context.WithValues(ctx, "path", meta.GetFile())
156+
return handlers.HandleFile(fileCtx, readCloser, &chunkSkel, reporter)
157+
}

pkg/sources/github/github.go

Lines changed: 0 additions & 140 deletions
Original file line numberDiff line numberDiff line change
@@ -5,13 +5,11 @@ import (
55
"fmt"
66
"math/rand/v2"
77
"net/http"
8-
"net/url"
98
"os"
109
"regexp"
1110
"sort"
1211
"strings"
1312
"sync"
14-
"sync/atomic"
1513
"time"
1614

1715
"github.com/gobwas/glob"
@@ -23,10 +21,8 @@ import (
2321

2422
"github.com/trufflesecurity/trufflehog/v3/pkg/cache"
2523
"github.com/trufflesecurity/trufflehog/v3/pkg/cache/simple"
26-
"github.com/trufflesecurity/trufflehog/v3/pkg/common"
2724
"github.com/trufflesecurity/trufflehog/v3/pkg/context"
2825
"github.com/trufflesecurity/trufflehog/v3/pkg/giturl"
29-
"github.com/trufflesecurity/trufflehog/v3/pkg/handlers"
3026
"github.com/trufflesecurity/trufflehog/v3/pkg/pb/source_metadatapb"
3127
"github.com/trufflesecurity/trufflehog/v3/pkg/pb/sourcespb"
3228
"github.com/trufflesecurity/trufflehog/v3/pkg/sanitizer"
@@ -323,37 +319,6 @@ func (s *Source) visibilityOf(ctx context.Context, repoURL string) source_metada
323319
return repoInfo.visibility
324320
}
325321

326-
// Chunks emits chunks of bytes over a channel.
327-
func (s *Source) Chunks(ctx context.Context, chunksChan chan *sources.Chunk, targets ...sources.ChunkingTarget) error {
328-
chunksReporter := sources.ChanReporter{Ch: chunksChan}
329-
// If targets are provided, we're only scanning the data in those targets.
330-
// Otherwise, we're scanning all data.
331-
// This allows us to only scan the commit where a vulnerability was found.
332-
if len(targets) > 0 {
333-
errs := s.scanTargets(ctx, targets, chunksReporter)
334-
return errors.Join(errs...)
335-
}
336-
337-
// Reset consumption and rate limit metrics on each run.
338-
githubNumRateLimitEncountered.WithLabelValues(s.name).Set(0)
339-
githubSecondsSpentRateLimited.WithLabelValues(s.name).Set(0)
340-
githubReposScanned.WithLabelValues(s.name).Set(0)
341-
342-
// We don't care about handling enumerated values as they happen during
343-
// the normal Chunks flow because we enumerate and scan in two steps.
344-
noopReporter := sources.VisitorReporter{
345-
VisitUnit: func(context.Context, sources.SourceUnit) error {
346-
return nil
347-
},
348-
}
349-
err := s.Enumerate(ctx, noopReporter)
350-
if err != nil {
351-
return fmt.Errorf("error enumerating: %w", err)
352-
}
353-
354-
return s.scan(ctx, chunksReporter)
355-
}
356-
357322
// Enumerate enumerates the GitHub source based on authentication method and
358323
// user configuration. It populates s.filteredRepoCache, s.repoInfoCache,
359324
// s.memberCache, s.totalRepoSize, s.orgsCache, and s.repos. Additionally,
@@ -624,47 +589,6 @@ func createGitHubClient(httpClient *http.Client, apiEndpoint string) (*github.Cl
624589
return github.NewClient(httpClient).WithEnterpriseURLs(apiEndpoint, apiEndpoint)
625590
}
626591

627-
func (s *Source) scan(ctx context.Context, reporter sources.ChunkReporter) error {
628-
var scannedCount uint64 = 1
629-
630-
ctx.Logger().V(2).Info("Found repos to scan", "count", len(s.repos))
631-
632-
// If there is resume information available, limit this scan to only the repos that still need scanning.
633-
reposToScan, progressIndexOffset := sources.FilterReposToResume(s.repos, s.GetProgress().EncodedResumeInfo)
634-
s.repos = reposToScan
635-
636-
for i, repoURL := range s.repos {
637-
s.jobPool.Go(func() error {
638-
if common.IsDone(ctx) {
639-
return nil
640-
}
641-
ctx := context.WithValue(ctx, "repo", repoURL)
642-
643-
// TODO: set progress complete is being called concurrently with i
644-
s.setProgressCompleteWithRepo(i, progressIndexOffset, repoURL)
645-
// Ensure the repo is removed from the resume info after being scanned.
646-
defer func(s *Source, repoURL string) {
647-
s.resumeInfoMutex.Lock()
648-
defer s.resumeInfoMutex.Unlock()
649-
s.resumeInfoSlice = sources.RemoveRepoFromResumeInfo(s.resumeInfoSlice, repoURL)
650-
}(s, repoURL)
651-
652-
if err := s.scanRepo(ctx, repoURL, reporter); err != nil {
653-
ctx.Logger().Error(err, "error scanning repo")
654-
return nil
655-
}
656-
657-
atomic.AddUint64(&scannedCount, 1)
658-
return nil
659-
})
660-
}
661-
662-
_ = s.jobPool.Wait()
663-
s.SetProgressComplete(len(s.repos), len(s.repos), "Completed GitHub scan", "")
664-
665-
return nil
666-
}
667-
668592
// scanRepo attempts to scan the provided URL and any associated wiki and
669593
// comments if configured. An error is returned if we could not find necessary
670594
// repository metadata or clone the repo, otherwise all errors are reported to
@@ -1500,70 +1424,6 @@ func (s *Source) chunkPullRequestComments(ctx context.Context, repoInfo repoInfo
15001424
return nil
15011425
}
15021426

1503-
func (s *Source) scanTargets(ctx context.Context, targets []sources.ChunkingTarget, reporter sources.ChunkReporter) []error {
1504-
var errs []error
1505-
for _, tgt := range targets {
1506-
if err := s.scanTarget(ctx, tgt, reporter); err != nil {
1507-
ctx.Logger().Error(err, "error scanning target")
1508-
errs = append(errs, &sources.TargetedScanError{Err: err, SecretID: tgt.SecretID})
1509-
}
1510-
}
1511-
1512-
return errs
1513-
}
1514-
1515-
func (s *Source) scanTarget(ctx context.Context, target sources.ChunkingTarget, reporter sources.ChunkReporter) error {
1516-
metaType, ok := target.QueryCriteria.GetData().(*source_metadatapb.MetaData_Github)
1517-
if !ok {
1518-
return fmt.Errorf("unable to cast metadata type for targeted scan")
1519-
}
1520-
meta := metaType.Github
1521-
1522-
u, err := url.Parse(meta.GetLink())
1523-
if err != nil {
1524-
return fmt.Errorf("unable to parse GitHub URL: %w", err)
1525-
}
1526-
1527-
// The owner is the second segment and the repo is the third segment of the path.
1528-
// Ex: https://github.com/owner/repo/.....
1529-
segments := strings.Split(u.Path, "/")
1530-
if len(segments) < 3 {
1531-
return fmt.Errorf("invalid GitHub URL")
1532-
}
1533-
1534-
readCloser, resp, err := s.connector.APIClient().Repositories.DownloadContents(
1535-
ctx,
1536-
segments[1],
1537-
segments[2],
1538-
meta.GetFile(),
1539-
&github.RepositoryContentGetOptions{Ref: meta.GetCommit()})
1540-
// As of this writing, if the returned readCloser is not nil, it's just the Body of the returned github.Response, so
1541-
// there's no need to independently close it.
1542-
if resp != nil && resp.Body != nil {
1543-
defer resp.Body.Close()
1544-
}
1545-
if err != nil {
1546-
return fmt.Errorf("could not download file for scan: %w", err)
1547-
}
1548-
if resp.StatusCode != http.StatusOK {
1549-
return fmt.Errorf("unexpected HTTP response status when trying to download file for scan: %v", resp.Status)
1550-
}
1551-
1552-
chunkSkel := sources.Chunk{
1553-
SourceType: s.Type(),
1554-
SourceName: s.name,
1555-
SourceID: s.SourceID(),
1556-
JobID: s.JobID(),
1557-
SecretID: target.SecretID,
1558-
SourceMetadata: &source_metadatapb.MetaData{
1559-
Data: &source_metadatapb.MetaData_Github{Github: meta},
1560-
},
1561-
Verify: s.verify,
1562-
}
1563-
fileCtx := context.WithValues(ctx, "path", meta.GetFile())
1564-
return handlers.HandleFile(fileCtx, readCloser, &chunkSkel, reporter)
1565-
}
1566-
15671427
func (s *Source) ChunkUnit(ctx context.Context, unit sources.SourceUnit, reporter sources.ChunkReporter) error {
15681428
repoURL, _ := unit.SourceUnitID()
15691429
ctx = context.WithValue(ctx, "repo", repoURL)

pkg/sources/sources.go

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,9 @@ type Source interface {
7272
// ChunkingTarget parameters, the caller can direct the function to retrieve
7373
// specific chunks of data. This targeted approach allows for efficient and
7474
// intentional data processing, beneficial when verifying or rechecking specific data points.
75+
//
76+
// Deprecated: sources should be migrated to use SourceUnitEnumChunker instead.
77+
// https://github.com/trufflesecurity/trufflehog/pull/3298#issuecomment-2510010947
7578
Chunks(ctx context.Context, chunksChan chan *Chunk, targets ...ChunkingTarget) error
7679
// GetProgress is the completion progress (percentage) for Scanned Source.
7780
GetProgress() *Progress

0 commit comments

Comments
 (0)