Skip to content

Commit 1a1febe

Browse files
committed
refactor(github): split enterprise-specific code
1 parent 7008ad2 commit 1a1febe

7 files changed

+856
-804
lines changed

pkg/sources/github/enterprise.go

Lines changed: 157 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,157 @@
1+
// Code used exclusively by the enterprise version.
2+
// https://github.com/trufflesecurity/trufflehog/pull/3298#issuecomment-2510010947
3+
4+
package github
5+
6+
import (
7+
"errors"
8+
"fmt"
9+
"net/http"
10+
"net/url"
11+
"strings"
12+
"sync/atomic"
13+
14+
"github.com/google/go-github/v67/github"
15+
16+
"github.com/trufflesecurity/trufflehog/v3/pkg/common"
17+
"github.com/trufflesecurity/trufflehog/v3/pkg/context"
18+
"github.com/trufflesecurity/trufflehog/v3/pkg/handlers"
19+
"github.com/trufflesecurity/trufflehog/v3/pkg/pb/source_metadatapb"
20+
"github.com/trufflesecurity/trufflehog/v3/pkg/sources"
21+
)
22+
23+
// Chunks emits chunks of bytes over a channel.
24+
func (s *Source) Chunks(ctx context.Context, chunksChan chan *sources.Chunk, targets ...sources.ChunkingTarget) error {
25+
chunksReporter := sources.ChanReporter{Ch: chunksChan}
26+
// If targets are provided, we're only scanning the data in those targets.
27+
// Otherwise, we're scanning all data.
28+
// This allows us to only scan the commit where a vulnerability was found.
29+
if len(targets) > 0 {
30+
errs := s.scanTargets(ctx, targets, chunksReporter)
31+
return errors.Join(errs...)
32+
}
33+
34+
// Reset consumption and rate limit metrics on each run.
35+
githubNumRateLimitEncountered.WithLabelValues(s.name).Set(0)
36+
githubSecondsSpentRateLimited.WithLabelValues(s.name).Set(0)
37+
githubReposScanned.WithLabelValues(s.name).Set(0)
38+
39+
// We don't care about handling enumerated values as they happen during
40+
// the normal Chunks flow because we enumerate and scan in two steps.
41+
noopReporter := sources.VisitorReporter{
42+
VisitUnit: func(context.Context, sources.SourceUnit) error {
43+
return nil
44+
},
45+
}
46+
err := s.Enumerate(ctx, noopReporter)
47+
if err != nil {
48+
return fmt.Errorf("error enumerating: %w", err)
49+
}
50+
51+
return s.scan(ctx, chunksReporter)
52+
}
53+
54+
func (s *Source) scan(ctx context.Context, reporter sources.ChunkReporter) error {
55+
var scannedCount uint64 = 1
56+
57+
ctx.Logger().V(2).Info("Found repos to scan", "count", len(s.repos))
58+
59+
// If there is resume information available, limit this scan to only the repos that still need scanning.
60+
reposToScan, progressIndexOffset := sources.FilterReposToResume(s.repos, s.GetProgress().EncodedResumeInfo)
61+
s.repos = reposToScan
62+
63+
for i, repoURL := range s.repos {
64+
s.jobPool.Go(func() error {
65+
if common.IsDone(ctx) {
66+
return nil
67+
}
68+
ctx := context.WithValue(ctx, "repo", repoURL)
69+
70+
// TODO: set progress complete is being called concurrently with i
71+
s.setProgressCompleteWithRepo(i, progressIndexOffset, repoURL)
72+
// Ensure the repo is removed from the resume info after being scanned.
73+
defer func(s *Source, repoURL string) {
74+
s.resumeInfoMutex.Lock()
75+
defer s.resumeInfoMutex.Unlock()
76+
s.resumeInfoSlice = sources.RemoveRepoFromResumeInfo(s.resumeInfoSlice, repoURL)
77+
}(s, repoURL)
78+
79+
if err := s.scanRepo(ctx, repoURL, reporter); err != nil {
80+
ctx.Logger().Error(err, "error scanning repo")
81+
return nil
82+
}
83+
84+
atomic.AddUint64(&scannedCount, 1)
85+
return nil
86+
})
87+
}
88+
89+
_ = s.jobPool.Wait()
90+
s.SetProgressComplete(len(s.repos), len(s.repos), "Completed GitHub scan", "")
91+
92+
return nil
93+
}
94+
95+
func (s *Source) scanTargets(ctx context.Context, targets []sources.ChunkingTarget, reporter sources.ChunkReporter) []error {
96+
var errs []error
97+
for _, tgt := range targets {
98+
if err := s.scanTarget(ctx, tgt, reporter); err != nil {
99+
ctx.Logger().Error(err, "error scanning target")
100+
errs = append(errs, &sources.TargetedScanError{Err: err, SecretID: tgt.SecretID})
101+
}
102+
}
103+
104+
return errs
105+
}
106+
107+
func (s *Source) scanTarget(ctx context.Context, target sources.ChunkingTarget, reporter sources.ChunkReporter) error {
108+
metaType, ok := target.QueryCriteria.GetData().(*source_metadatapb.MetaData_Github)
109+
if !ok {
110+
return fmt.Errorf("unable to cast metadata type for targeted scan")
111+
}
112+
meta := metaType.Github
113+
114+
u, err := url.Parse(meta.GetLink())
115+
if err != nil {
116+
return fmt.Errorf("unable to parse GitHub URL: %w", err)
117+
}
118+
119+
// The owner is the second segment and the repo is the third segment of the path.
120+
// Ex: https://github.com/owner/repo/.....
121+
segments := strings.Split(u.Path, "/")
122+
if len(segments) < 3 {
123+
return fmt.Errorf("invalid GitHub URL")
124+
}
125+
126+
readCloser, resp, err := s.connector.APIClient().Repositories.DownloadContents(
127+
ctx,
128+
segments[1],
129+
segments[2],
130+
meta.GetFile(),
131+
&github.RepositoryContentGetOptions{Ref: meta.GetCommit()})
132+
// As of this writing, if the returned readCloser is not nil, it's just the Body of the returned github.Response, so
133+
// there's no need to independently close it.
134+
if resp != nil && resp.Body != nil {
135+
defer resp.Body.Close()
136+
}
137+
if err != nil {
138+
return fmt.Errorf("could not download file for scan: %w", err)
139+
}
140+
if resp.StatusCode != http.StatusOK {
141+
return fmt.Errorf("unexpected HTTP response status when trying to download file for scan: %v", resp.Status)
142+
}
143+
144+
chunkSkel := sources.Chunk{
145+
SourceType: s.Type(),
146+
SourceName: s.name,
147+
SourceID: s.SourceID(),
148+
JobID: s.JobID(),
149+
SecretID: target.SecretID,
150+
SourceMetadata: &source_metadatapb.MetaData{
151+
Data: &source_metadatapb.MetaData_Github{Github: meta},
152+
},
153+
Verify: s.verify,
154+
}
155+
fileCtx := context.WithValues(ctx, "path", meta.GetFile())
156+
return handlers.HandleFile(fileCtx, readCloser, &chunkSkel, reporter)
157+
}

0 commit comments

Comments
 (0)