@@ -4,10 +4,14 @@ import (
44 "bufio"
55 "bytes"
66 "context"
7+ "crypto/sha256"
78 _ "embed"
9+ "encoding/hex"
810 "errors"
911 "fmt"
12+ "io"
1013 "io/fs"
14+ "net/http"
1115 "os"
1216 "path/filepath"
1317 "runtime"
@@ -44,6 +48,183 @@ func (r *Runtime) Supports(tool types.Tool, cmd []string) bool {
4448 len (cmd ) > 0 && cmd [0 ] == "${GPTSCRIPT_TOOL_DIR}/bin/gptscript-go-tool"
4549}
4650
51+ type release struct {
52+ account , repo , label string
53+ }
54+
55+ func (r release ) checksumTxt () string {
56+ return fmt .Sprintf (
57+ "https://github.com/%s/%s/releases/download/%s/checksums.txt" ,
58+ r .account ,
59+ r .repo ,
60+ r .label )
61+ }
62+
63+ func (r release ) binURL () string {
64+ return fmt .Sprintf (
65+ "https://github.com/%s/%s/releases/download/%s/%s" ,
66+ r .account ,
67+ r .repo ,
68+ r .label ,
69+ r .srcBinName ())
70+ }
71+
72+ func (r release ) targetBinName () string {
73+ suffix := ""
74+ if runtime .GOOS == "windows" {
75+ suffix = ".exe"
76+ }
77+
78+ return "gptscript-go-tool" + suffix
79+ }
80+
81+ func (r release ) srcBinName () string {
82+ suffix := ""
83+ if runtime .GOOS == "windows" {
84+ suffix = ".exe"
85+ }
86+
87+ return r .repo + "-" +
88+ runtime .GOOS + "-" +
89+ runtime .GOARCH + suffix
90+ }
91+
92+ func getLatestRelease (tool types.Tool ) (* release , bool ) {
93+ if tool .Source .Repo == nil || ! strings .HasPrefix (tool .Source .Repo .Root , "https://github.com/" ) {
94+ return nil , false
95+ }
96+
97+ parts := strings .Split (strings .TrimPrefix (strings .TrimSuffix (tool .Source .Repo .Root , ".git" ), "https://" ), "/" )
98+ if len (parts ) != 3 {
99+ return nil , false
100+ }
101+
102+ client := http.Client {
103+ CheckRedirect : func (_ * http.Request , _ []* http.Request ) error {
104+ return http .ErrUseLastResponse
105+ },
106+ }
107+
108+ resp , err := client .Get (fmt .Sprintf ("https://github.com/%s/%s/releases/latest" , parts [1 ], parts [2 ]))
109+ if err != nil || resp .StatusCode != http .StatusFound {
110+ // ignore error
111+ return nil , false
112+ }
113+ defer resp .Body .Close ()
114+
115+ target := resp .Header .Get ("Location" )
116+ if target == "" {
117+ return nil , false
118+ }
119+
120+ account , repo := parts [1 ], parts [2 ]
121+ parts = strings .Split (target , "/" )
122+ label := parts [len (parts )- 1 ]
123+
124+ return & release {
125+ account : account ,
126+ repo : repo ,
127+ label : label ,
128+ }, true
129+ }
130+
131+ func get (ctx context.Context , url string ) (* http.Response , error ) {
132+ req , err := http .NewRequestWithContext (ctx , http .MethodGet , url , nil )
133+ if err != nil {
134+ return nil , err
135+ }
136+
137+ resp , err := http .DefaultClient .Do (req )
138+ if err != nil {
139+ return nil , err
140+ } else if resp .StatusCode != http .StatusOK {
141+ _ = resp .Body .Close ()
142+ return nil , fmt .Errorf ("bad HTTP status code: %d" , resp .StatusCode )
143+ }
144+
145+ return resp , nil
146+ }
147+
148+ func downloadBin (ctx context.Context , checksum , src , url , bin string ) error {
149+ resp , err := get (ctx , url )
150+ if err != nil {
151+ return err
152+ }
153+ defer resp .Body .Close ()
154+
155+ if err := os .MkdirAll (filepath .Join (src , "bin" ), 0755 ); err != nil {
156+ return err
157+ }
158+
159+ targetFile , err := os .Create (filepath .Join (src , "bin" , bin ))
160+ if err != nil {
161+ return err
162+ }
163+
164+ digest := sha256 .New ()
165+
166+ if _ , err := io .Copy (io .MultiWriter (targetFile , digest ), resp .Body ); err != nil {
167+ return err
168+ }
169+
170+ if err := targetFile .Close (); err != nil {
171+ return nil
172+ }
173+
174+ if got := hex .EncodeToString (digest .Sum (nil )); got != checksum {
175+ return fmt .Errorf ("checksum mismatch %s != %s" , got , checksum )
176+ }
177+
178+ if err := os .Chmod (targetFile .Name (), 0755 ); err != nil {
179+ return err
180+ }
181+
182+ return nil
183+ }
184+
185+ func getChecksum (ctx context.Context , rel * release ) string {
186+ resp , err := get (ctx , rel .checksumTxt ())
187+ if err != nil {
188+ // ignore error
189+ return ""
190+ }
191+ defer resp .Body .Close ()
192+
193+ scan := bufio .NewScanner (resp .Body )
194+ for scan .Scan () {
195+ fields := strings .Fields (scan .Text ())
196+ if len (fields ) != 2 || fields [1 ] != rel .srcBinName () {
197+ continue
198+ }
199+ return fields [0 ]
200+ }
201+
202+ return ""
203+ }
204+
205+ func (r * Runtime ) Binary (ctx context.Context , tool types.Tool , _ , toolSource string , env []string ) (bool , []string , error ) {
206+ if ! tool .Source .IsGit () {
207+ return false , nil , nil
208+ }
209+
210+ rel , ok := getLatestRelease (tool )
211+ if ! ok {
212+ return false , nil , nil
213+ }
214+
215+ checksum := getChecksum (ctx , rel )
216+ if checksum == "" {
217+ return false , nil , nil
218+ }
219+
220+ if err := downloadBin (ctx , checksum , toolSource , rel .binURL (), rel .targetBinName ()); err != nil {
221+ // ignore error
222+ return false , nil , nil
223+ }
224+
225+ return true , env , nil
226+ }
227+
47228func (r * Runtime ) Setup (ctx context.Context , _ types.Tool , dataRoot , toolSource string , env []string ) ([]string , error ) {
48229 binPath , err := r .getRuntime (ctx , dataRoot )
49230 if err != nil {
0 commit comments