@@ -2,6 +2,7 @@ package github
22
33import (
44 "context"
5+ "crypto/tls"
56 "encoding/json"
67 "fmt"
78 "io"
@@ -18,52 +19,63 @@ import (
1819 "github.com/gptscript-ai/gptscript/pkg/types"
1920)
2021
21- const (
22- GithubPrefix = "github.com/"
23- githubRepoURL = "https://github.com/%s/%s.git"
24- githubDownloadURL = "https://raw.githubusercontent.com/%s/%s/%s/%s"
25- githubCommitURL = "https://api.github.com/repos/%s/%s/commits/%s"
26- )
22+ type GithubConfig struct {
23+ Prefix string
24+ RepoURL string
25+ DownloadURL string
26+ CommitURL string
27+ AuthToken string
28+ }
2729
2830var (
29- githubAuthToken = os .Getenv ("GITHUB_AUTH_TOKEN" )
30- log = mvl .Package ()
31+ log = mvl .Package ()
32+ defaultGithubConfig = & GithubConfig {
33+ Prefix : "github.com/" ,
34+ RepoURL : "https://github.com/%s/%s.git" ,
35+ DownloadURL : "https://raw.githubusercontent.com/%s/%s/%s/%s" ,
36+ CommitURL : "https://api.github.com/repos/%s/%s/commits/%s" ,
37+ AuthToken : os .Getenv ("GITHUB_AUTH_TOKEN" ),
38+ }
3139)
3240
3341func init () {
3442 loader .AddVSC (Load )
3543}
3644
37- func getCommitLsRemote (ctx context.Context , account , repo , ref string ) (string , error ) {
38- url := fmt .Sprintf (githubRepoURL , account , repo )
45+ func getCommitLsRemote (ctx context.Context , account , repo , ref string , config * GithubConfig ) (string , error ) {
46+ url := fmt .Sprintf (config . RepoURL , account , repo )
3947 return git .LsRemote (ctx , url , ref )
4048}
4149
4250// regexp to match a git commit id
4351var commitRegexp = regexp .MustCompile ("^[a-f0-9]{40}$" )
4452
45- func getCommit (ctx context.Context , account , repo , ref string ) (string , error ) {
53+ func getCommit (ctx context.Context , account , repo , ref string , config * GithubConfig ) (string , error ) {
4654 if commitRegexp .MatchString (ref ) {
4755 return ref , nil
4856 }
4957
50- url := fmt .Sprintf (githubCommitURL , account , repo , ref )
58+ url := fmt .Sprintf (config . CommitURL , account , repo , ref )
5159 req , err := http .NewRequestWithContext (ctx , http .MethodGet , url , nil )
5260 if err != nil {
5361 return "" , fmt .Errorf ("failed to create request of %s/%s at %s: %w" , account , repo , url , err )
5462 }
5563
56- if githubAuthToken != "" {
57- req .Header .Add ("Authorization" , "Bearer " + githubAuthToken )
64+ if config . AuthToken != "" {
65+ req .Header .Add ("Authorization" , "Bearer " + config . AuthToken )
5866 }
5967
60- resp , err := http .DefaultClient .Do (req )
68+ client := http .DefaultClient
69+ if req .Host == config .Prefix && strings .ToLower (os .Getenv ("GH_ENTERPRISE_SKIP_VERIFY" )) == "true" {
70+ client = & http.Client {Transport : & http.Transport {TLSClientConfig : & tls.Config {InsecureSkipVerify : true }}}
71+ }
72+ resp , err := client .Do (req )
6173 if err != nil {
6274 return "" , err
6375 } else if resp .StatusCode != http .StatusOK {
6476 c , _ := io .ReadAll (resp .Body )
6577 resp .Body .Close ()
66- commit , fallBackErr := getCommitLsRemote (ctx , account , repo , ref )
78+ commit , fallBackErr := getCommitLsRemote (ctx , account , repo , ref , config )
6779 if fallBackErr == nil {
6880 return commit , nil
6981 }
@@ -88,8 +100,28 @@ func getCommit(ctx context.Context, account, repo, ref string) (string, error) {
88100 return commit .SHA , nil
89101}
90102
91- func Load (ctx context.Context , _ * cache.Client , urlName string ) (string , string , * types.Repo , bool , error ) {
92- if ! strings .HasPrefix (urlName , GithubPrefix ) {
103+ func LoaderForPrefix (prefix string ) func (context.Context , * cache.Client , string ) (string , string , * types.Repo , bool , error ) {
104+ return func (ctx context.Context , c * cache.Client , urlName string ) (string , string , * types.Repo , bool , error ) {
105+ return LoadWithConfig (ctx , c , urlName , NewGithubEnterpriseConfig (prefix ))
106+ }
107+ }
108+
109+ func Load (ctx context.Context , c * cache.Client , urlName string ) (string , string , * types.Repo , bool , error ) {
110+ return LoadWithConfig (ctx , c , urlName , defaultGithubConfig )
111+ }
112+
113+ func NewGithubEnterpriseConfig (prefix string ) * GithubConfig {
114+ return & GithubConfig {
115+ Prefix : prefix ,
116+ RepoURL : fmt .Sprintf ("https://%s/%%s/%%s.git" , prefix ),
117+ DownloadURL : fmt .Sprintf ("https://raw.%s/%%s/%%s/%%s/%%s" , prefix ),
118+ CommitURL : fmt .Sprintf ("https://%s/api/v3/repos/%%s/%%s/commits/%%s" , prefix ),
119+ AuthToken : os .Getenv ("GH_ENTERPRISE_TOKEN" ),
120+ }
121+ }
122+
123+ func LoadWithConfig (ctx context.Context , _ * cache.Client , urlName string , config * GithubConfig ) (string , string , * types.Repo , bool , error ) {
124+ if ! strings .HasPrefix (urlName , config .Prefix ) {
93125 return "" , "" , nil , false , nil
94126 }
95127
@@ -107,12 +139,12 @@ func Load(ctx context.Context, _ *cache.Client, urlName string) (string, string,
107139 account , repo := parts [1 ], parts [2 ]
108140 path := strings .Join (parts [3 :], "/" )
109141
110- ref , err := getCommit (ctx , account , repo , ref )
142+ ref , err := getCommit (ctx , account , repo , ref , config )
111143 if err != nil {
112144 return "" , "" , nil , false , err
113145 }
114146
115- downloadURL := fmt .Sprintf (githubDownloadURL , account , repo , ref , path )
147+ downloadURL := fmt .Sprintf (config . DownloadURL , account , repo , ref , path )
116148 if path == "" || path == "/" || ! strings .Contains (parts [len (parts )- 1 ], "." ) {
117149 var (
118150 testPath string
@@ -124,13 +156,20 @@ func Load(ctx context.Context, _ *cache.Client, urlName string) (string, string,
124156 } else {
125157 testPath = path + "/" + ext
126158 }
127- testURL = fmt .Sprintf (githubDownloadURL , account , repo , ref , testPath )
159+ testURL = fmt .Sprintf (config . DownloadURL , account , repo , ref , testPath )
128160 if i == len (types .DefaultFiles )- 1 {
129161 // no reason to test the last one, we are just going to use it. Being that the default list is only
130162 // two elements this loop could have been one check, but hey over-engineered code ftw.
131163 break
132164 }
133- if resp , err := http .Head (testURL ); err == nil {
165+ headReq , err := http .NewRequest ("HEAD" , testURL , nil )
166+ if err != nil {
167+ break
168+ }
169+ if config .AuthToken != "" {
170+ headReq .Header .Add ("Authorization" , "Bearer " + config .AuthToken )
171+ }
172+ if resp , err := http .DefaultClient .Do (headReq ); err == nil {
134173 _ = resp .Body .Close ()
135174 if resp .StatusCode == 200 {
136175 break
@@ -141,9 +180,9 @@ func Load(ctx context.Context, _ *cache.Client, urlName string) (string, string,
141180 path = testPath
142181 }
143182
144- return downloadURL , githubAuthToken , & types.Repo {
183+ return downloadURL , config . AuthToken , & types.Repo {
145184 VCS : "git" ,
146- Root : fmt .Sprintf (githubRepoURL , account , repo ),
185+ Root : fmt .Sprintf (config . RepoURL , account , repo ),
147186 Path : gpath .Dir (path ),
148187 Name : gpath .Base (path ),
149188 Revision : ref ,
0 commit comments