Skip to content

Commit 232ea58

Browse files
dido18MatteoPologrutolucarin91
authored
refac(updater): refactor the updater package
Co-authored-by: MatteoPologruto <m.pologruto@arduino.cc> Co-authored-by: lucarin91 <lucarin@protonmail.com>
0 parents  commit 232ea58

File tree

14 files changed

+1173
-0
lines changed

14 files changed

+1173
-0
lines changed

cmd/releaser/main.go

Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,73 @@
1+
package main
2+
3+
import (
4+
"encoding/json"
5+
"flag"
6+
"fmt"
7+
"log"
8+
"os"
9+
"runtime"
10+
11+
"github.com/arduino/go-updater/releaser"
12+
)
13+
14+
func main() {
15+
var (
16+
outputDir string
17+
platform releaser.Platform
18+
)
19+
20+
platform = defaultPlatform()
21+
22+
flag.StringVar(&outputDir, "o", "public", "Output directory for writing updates")
23+
flag.Var(&platform, "platform", "Target platform in the form OS-ARCH. Defaults to running os/arch or the combination of the environment variables GOOS and GOARCH if both are set.")
24+
flag.Usage = printUsage
25+
flag.Parse()
26+
27+
if flag.NArg() < 2 {
28+
flag.Usage()
29+
os.Exit(1)
30+
}
31+
32+
inputPath := flag.Arg(0)
33+
version := flag.Arg(1)
34+
35+
manifest, err := releaser.CreateRelease(inputPath, platform, version, outputDir)
36+
if err != nil {
37+
log.Fatalf("could not create release: %v", err)
38+
}
39+
fmt.Println("Release created successfully!")
40+
jsonBytes, err := json.MarshalIndent(manifest, "", " ")
41+
if err != nil {
42+
log.Fatalf("could not marshal manifest to JSON: %v", err)
43+
}
44+
fmt.Println(string(jsonBytes))
45+
}
46+
47+
func defaultPlatform() releaser.Platform {
48+
goos := os.Getenv("GOOS")
49+
goarch := os.Getenv("GOARCH")
50+
if goos != "" && goarch != "" {
51+
return releaser.NewPlatform(goos, goarch)
52+
}
53+
return releaser.NewPlatform(runtime.GOOS, runtime.GOARCH)
54+
}
55+
56+
func printUsage() {
57+
fmt.Fprintf(os.Stderr, `
58+
Usage:
59+
go-selfupdate [flags] <binary-or-dir> <version>
60+
61+
Positional arguments:
62+
<binary-or-dir> Path to the binary file or directory containing binaries
63+
<version> Version string to embed in the update metadata
64+
65+
Flags:
66+
`)
67+
flag.PrintDefaults()
68+
fmt.Fprintln(os.Stderr, `
69+
Examples:
70+
go-selfupdate myapp 1.2.3
71+
go-selfupdate -o public -platform linux-amd64 myapp 1.2.3
72+
go-selfupdate /tmp/mybinares/ 1.2.3`)
73+
}

releaser/http_client.go

Lines changed: 109 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,109 @@
1+
package releaser
2+
3+
import (
4+
"crypto/sha256"
5+
"encoding/json"
6+
"fmt"
7+
"io"
8+
"net/http"
9+
"net/url"
10+
)
11+
12+
// Client holds the base URL, command name, allows custom HTTP client, and optional headers.
13+
type Client struct {
14+
BaseURL *url.URL
15+
CmdName string
16+
HTTPClient HTTPDoer
17+
Headers map[string]string // Optional headers to add to each request
18+
}
19+
20+
// HTTPDoer is an interface for http.Client or mocks.
21+
type HTTPDoer interface {
22+
Do(req *http.Request) (*http.Response, error)
23+
}
24+
25+
// Option is a functional option for configuring Client.
26+
type Option func(*Client)
27+
28+
// WithHeaders sets custom headers for the Client.
29+
func WithHeaders(headers map[string]string) Option {
30+
return func(c *Client) {
31+
c.Headers = headers
32+
}
33+
}
34+
35+
// WithHTTPClient sets a custom HTTP client for the Client.
36+
func WithHTTPClient(client HTTPDoer) Option {
37+
return func(c *Client) {
38+
c.HTTPClient = client
39+
}
40+
}
41+
42+
// NewClient creates a new Client with optional configuration.
43+
func NewClient(baseURL *url.URL, cmdName string, opts ...Option) *Client {
44+
c := &Client{
45+
BaseURL: baseURL,
46+
CmdName: cmdName,
47+
HTTPClient: http.DefaultClient,
48+
Headers: nil,
49+
}
50+
for _, opt := range opts {
51+
opt(c)
52+
}
53+
return c
54+
}
55+
56+
// addHeaders adds custom headers to the request if present.
57+
func (c *Client) addHeaders(req *http.Request) {
58+
for k, v := range c.Headers {
59+
req.Header.Set(k, v)
60+
}
61+
}
62+
63+
// GetManifest fetches and decodes the manifest for the given platform.
64+
func (c *Client) GetManifest(plat Platform) (Manifest, error) {
65+
manifestURL := c.BaseURL.JoinPath(c.CmdName, plat.String()+".json").String()
66+
req, err := http.NewRequest("GET", manifestURL, nil)
67+
if err != nil {
68+
return Manifest{}, fmt.Errorf("failed to create request: %w", err)
69+
}
70+
c.addHeaders(req)
71+
// #nosec G107 -- manifestURL is constructed from trusted config and parameters
72+
resp, err := c.HTTPClient.Do(req)
73+
if err != nil {
74+
return Manifest{}, fmt.Errorf("failed to GET manifest: %w", err)
75+
}
76+
defer resp.Body.Close()
77+
if resp.StatusCode != http.StatusOK {
78+
return Manifest{}, fmt.Errorf("bad http status from %s: %v", manifestURL, resp.Status)
79+
}
80+
81+
var res Manifest
82+
if err := json.NewDecoder(resp.Body).Decode(&res); err != nil {
83+
return Manifest{}, fmt.Errorf("invalid manifest JSON: %w", err)
84+
}
85+
if len(res.Sha256) != sha256.Size {
86+
return Manifest{}, fmt.Errorf("bad sha256 in manifest: got %d bytes", len(res.Sha256))
87+
}
88+
return res, nil
89+
}
90+
91+
// FetchZip fetches the zip for the given version and platform.
92+
func (c *Client) FetchZip(version string, plat Platform) (io.ReadCloser, error) {
93+
zipURL := c.BaseURL.JoinPath(c.CmdName, version, plat.String()+".zip").String()
94+
req, err := http.NewRequest("GET", zipURL, nil)
95+
if err != nil {
96+
return nil, fmt.Errorf("failed to create request: %w", err)
97+
}
98+
c.addHeaders(req)
99+
// #nosec G107 -- zipURL is constructed from trusted config and parameters
100+
resp, err := c.HTTPClient.Do(req)
101+
if err != nil {
102+
return nil, fmt.Errorf("failed to GET zip: %w", err)
103+
}
104+
if resp.StatusCode != http.StatusOK {
105+
resp.Body.Close()
106+
return nil, fmt.Errorf("bad http status from %s: %v", zipURL, resp.Status)
107+
}
108+
return resp.Body, nil
109+
}

releaser/platform.go

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,60 @@
1+
package releaser
2+
3+
import (
4+
"errors"
5+
"strings"
6+
)
7+
8+
type Platform struct {
9+
OS string
10+
Arch string
11+
}
12+
13+
func NewPlatform(os string, arch string) Platform {
14+
return Platform{
15+
OS: os,
16+
Arch: arch,
17+
}
18+
}
19+
20+
// Parse parses a string like "linux-amd64" into a Platform struct.
21+
func Parse(s string) (Platform, error) {
22+
parts := strings.Split(s, "-")
23+
if len(parts) != 2 {
24+
return Platform{}, errors.New("platform string must be in the form os-arch, e.g. linux-amd64")
25+
}
26+
os := parts[0]
27+
if os == "" {
28+
return Platform{}, errors.New("missing OS in platform string")
29+
}
30+
arch := parts[1]
31+
if arch == "" {
32+
return Platform{}, errors.New("missing Arch in platform string")
33+
}
34+
return Platform{OS: os, Arch: arch}, nil
35+
}
36+
37+
func MustParse(s string) Platform {
38+
id, err := Parse(s)
39+
if err != nil {
40+
panic(err)
41+
}
42+
return id
43+
}
44+
45+
// String returns the platform as "os-arch"
46+
func (p Platform) String() string {
47+
return p.OS + "-" + p.Arch
48+
}
49+
50+
// Set parses and sets the platform from a string like "linux-amd64"
51+
// Used for flag.Value interface
52+
func (p *Platform) Set(s string) error {
53+
platform, err := Parse(s)
54+
if err != nil {
55+
return err
56+
}
57+
p.OS = platform.OS
58+
p.Arch = platform.Arch
59+
return nil
60+
}

releaser/platform_test.go

Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,72 @@
1+
package releaser
2+
3+
import (
4+
"testing"
5+
)
6+
7+
func TestParsePlatform_Valid(t *testing.T) {
8+
tests := []struct {
9+
input string
10+
expected Platform
11+
}{
12+
{"linux-amd64", Platform{"linux", "amd64"}},
13+
{"darwin-arm64", Platform{"darwin", "arm64"}},
14+
{"windows-386", Platform{"windows", "386"}},
15+
}
16+
17+
for _, tt := range tests {
18+
got, err := Parse(tt.input)
19+
if err != nil {
20+
t.Errorf("ParsePlatform(%q) unexpected error: %v", tt.input, err)
21+
}
22+
if got != tt.expected {
23+
t.Errorf("ParsePlatform(%q) = %+v, want %+v", tt.input, got, tt.expected)
24+
}
25+
}
26+
}
27+
28+
func TestParsePlatform_Invalid(t *testing.T) {
29+
cases := []struct {
30+
name string
31+
input string
32+
}{
33+
{"Empty", ""},
34+
{"Missing Arch", "linux"},
35+
{"Missing OS", "-amd64"},
36+
{"Extra Parts", "linux-amd64-extra"},
37+
{"Just Dash", "-"},
38+
}
39+
40+
for _, c := range cases {
41+
_, err := Parse(c.input)
42+
if err == nil {
43+
t.Errorf("ParsePlatform(%q) expected error, got nil", c.input)
44+
}
45+
}
46+
}
47+
48+
func TestPlatform_String(t *testing.T) {
49+
p := Platform{"linux", "amd64"}
50+
if got := p.String(); got != "linux-amd64" {
51+
t.Errorf("Platform.String() = %q, want %q", got, "linux-amd64")
52+
}
53+
}
54+
55+
func TestPlatform_Set(t *testing.T) {
56+
var p Platform
57+
err := p.Set("darwin-arm64")
58+
if err != nil {
59+
t.Fatalf("Set returned error: %v", err)
60+
}
61+
if p.OS != "darwin" || p.Arch != "arm64" {
62+
t.Errorf("Set did not set fields correctly: %+v", p)
63+
}
64+
}
65+
66+
func TestPlatform_Set_Invalid(t *testing.T) {
67+
var p Platform
68+
err := p.Set("badformat")
69+
if err == nil {
70+
t.Error("Set should return error for bad format")
71+
}
72+
}

0 commit comments

Comments
 (0)