Skip to content

Commit 0d9afcc

Browse files
committed
Additional func to retrieve default branch of an org/repo
1 parent 0e909e3 commit 0d9afcc

File tree

3 files changed

+106
-0
lines changed

3 files changed

+106
-0
lines changed

cmd/generic-autobumper/bumper/bumper.go

+10
Original file line numberDiff line numberDiff line change
@@ -681,6 +681,16 @@ func getLastBumpCommit(gerritAuthor, commitTag string) (string, error) {
681681
return outBuf.String(), nil
682682
}
683683

684+
// CreateOrUpdatePROnDefaultBranch retrieves the default branch for the repository
685+
// and creates/updates a PR on that branch without requiring the caller to specify the branch.
686+
func CreateOrUpdatePROnDefaultBranch(gc github.Client, org, repo string, opts *Options, prTitle, prBody string) error {
687+
repository, err := gc.GetRepo(org, repo)
688+
if err != nil {
689+
return fmt.Errorf("failed to get repository details for %s/%s: %w", org, repo, err)
690+
}
691+
return UpdatePR(gc, opts.GitHubOrg, opts.GitHubRepo, "", opts.GitHubLogin, repository.DefaultBranch, opts.HeadBranchName, true, prTitle, prBody)
692+
}
693+
684694
// getChangeId generates a change ID for the gerrit PR that is deterministic
685695
// rather than being random as is normally preferable.
686696
// In particular this chooses a change ID by hashing the last commit by the

cmd/generic-autobumper/bumper/bumper_test.go

+85
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,8 @@ import (
2424
"testing"
2525

2626
"sigs.k8s.io/prow/pkg/config/secret"
27+
"sigs.k8s.io/prow/pkg/github"
28+
"sigs.k8s.io/prow/pkg/github/fakegithub"
2729
)
2830

2931
func TestValidateOptions(t *testing.T) {
@@ -337,3 +339,86 @@ func TestCDToRootDir(t *testing.T) {
337339
})
338340
}
339341
}
342+
343+
func TestCreateOrUpdatePROnDefaultBranch(t *testing.T) {
344+
testCases := []struct {
345+
name string
346+
org string
347+
repo string
348+
prTitle string
349+
prBody string
350+
existingPRs []*github.PullRequest
351+
expectError bool
352+
errorMsg string
353+
}{
354+
{
355+
name: "no existing PRs",
356+
org: "org",
357+
repo: "repo",
358+
prTitle: "Test Title",
359+
prBody: "Test Body",
360+
existingPRs: []*github.PullRequest{},
361+
expectError: false,
362+
},
363+
{
364+
name: "existing PR with same title",
365+
org: "org",
366+
repo: "repo",
367+
prTitle: "Test Title",
368+
prBody: "Test Body",
369+
existingPRs: []*github.PullRequest{
370+
{
371+
Title: "Test Title",
372+
Body: "Existing PR Body",
373+
},
374+
},
375+
expectError: false,
376+
},
377+
{
378+
name: "error creating PR",
379+
org: "org",
380+
repo: "repo",
381+
prTitle: "Test Title",
382+
prBody: "Test Body",
383+
existingPRs: []*github.PullRequest{
384+
{
385+
Title: "Existing PR Title",
386+
Body: "Existing PR Body",
387+
},
388+
},
389+
expectError: true,
390+
errorMsg: "failed to create pull request",
391+
},
392+
}
393+
394+
for _, tc := range testCases {
395+
t.Run(tc.name, func(t *testing.T) {
396+
fakeClient := &fakegithub.FakeClient{
397+
PullRequests: map[int]*github.PullRequest{},
398+
}
399+
400+
for i, pr := range tc.existingPRs {
401+
fakeClient.PullRequests[i] = pr
402+
}
403+
404+
opts := &Options{
405+
GitHubOrg: tc.org,
406+
GitHubRepo: tc.repo,
407+
GitHubLogin: "login",
408+
HeadBranchName: "head-branch",
409+
}
410+
411+
err := CreateOrUpdatePROnDefaultBranch(fakeClient, tc.org, tc.repo, opts, tc.prTitle, tc.prBody)
412+
413+
if tc.expectError && err == nil {
414+
t.Errorf("Expected to get an error but the result is nil")
415+
}
416+
if !tc.expectError && err != nil {
417+
t.Errorf("Expected to not get an error but got one: %v", err)
418+
}
419+
if tc.expectError && err != nil && !strings.Contains(err.Error(), tc.errorMsg) {
420+
t.Errorf("Expected error message to contain %q but got %v", tc.errorMsg, err)
421+
}
422+
})
423+
}
424+
}

pkg/github/fakegithub/fakegithub.go

+11
Original file line numberDiff line numberDiff line change
@@ -484,6 +484,17 @@ func (f *FakeClient) GetPullRequestChanges(org, repo string, number int) ([]gith
484484
return f.PullRequestChanges[number], nil
485485
}
486486

487+
// ClosePullRequest closes a pull request.
488+
func (f *FakeClient) ClosePullRequest(org, repo string, number int) error {
489+
f.lock.Lock()
490+
defer f.lock.Unlock()
491+
if pr, ok := f.PullRequests[number]; ok {
492+
pr.State = "closed"
493+
return nil
494+
}
495+
return errors.New("not found")
496+
}
497+
487498
// GetRef returns the hash of a ref.
488499
func (f *FakeClient) GetRef(owner, repo, ref string) (string, error) {
489500
return TestRef, nil

0 commit comments

Comments
 (0)