Skip to content

Commit 81b9001

Browse files
committed
Implement white/blacklisting for socks and ssh
1 parent cf09f68 commit 81b9001

File tree

7 files changed

+511
-25
lines changed

7 files changed

+511
-25
lines changed

cmd/gost/main.go

+2
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,8 @@ func main() {
6565
glog.Fatal(err)
6666
}
6767

68+
glog.Info(serverNode)
69+
6870
wg.Add(1)
6971
go func(node gost.ProxyNode) {
7072
defer wg.Done()

node.go

+46-2
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,8 @@ type ProxyNode struct {
1919
Transport string // transport: ws/wss/tls/http2/tcp/udp/rtcp/rudp
2020
Remote string // remote address, used by tcp/udp port forwarding
2121
Users []*url.Userinfo // authentication for proxy
22+
Whitelist *Permissions
23+
Blacklist *Permissions
2224
values url.Values
2325
serverName string
2426
conn net.Conn
@@ -36,12 +38,36 @@ func ParseProxyNode(s string) (node ProxyNode, err error) {
3638
return
3739
}
3840

41+
query := u.Query()
42+
3943
node = ProxyNode{
4044
Addr: u.Host,
41-
values: u.Query(),
45+
values: query,
4246
serverName: u.Host,
4347
}
4448

49+
if query.Get("whitelist") != "" {
50+
node.Whitelist, err = ParsePermissions(query.Get("whitelist"))
51+
52+
if err != nil {
53+
glog.Fatal(err)
54+
}
55+
} else {
56+
// By default allow for everyting
57+
node.Whitelist, _ = ParsePermissions("*:*:*")
58+
}
59+
60+
if query.Get("blacklist") != "" {
61+
node.Blacklist, err = ParsePermissions(query.Get("blacklist"))
62+
63+
if err != nil {
64+
glog.Fatal(err)
65+
}
66+
} else {
67+
// By default block nothing
68+
node.Blacklist, _ = ParsePermissions("")
69+
}
70+
4571
if u.User != nil {
4672
node.Users = append(node.Users, u.User)
4773
}
@@ -126,6 +152,24 @@ func (node *ProxyNode) Get(key string) string {
126152
return node.values.Get(key)
127153
}
128154

155+
func (node *ProxyNode) Can(action string, addr string) bool {
156+
host, strport, err := net.SplitHostPort(addr)
157+
158+
if err != nil {
159+
return false
160+
}
161+
162+
port, err := strconv.Atoi(strport)
163+
164+
if err != nil {
165+
return false
166+
}
167+
168+
glog.V(LDEBUG).Infof("Can action: %s, host: %s, port %d", action, host, port)
169+
170+
return node.Whitelist.Can(action, host, port) && !node.Blacklist.Can(action, host, port)
171+
}
172+
129173
func (node *ProxyNode) getBool(key string) bool {
130174
s := node.Get(key)
131175
if b, _ := strconv.ParseBool(s); b {
@@ -162,5 +206,5 @@ func (node *ProxyNode) keyFile() string {
162206
}
163207

164208
func (node ProxyNode) String() string {
165-
return fmt.Sprintf("transport: %s, protocol: %s, addr: %s", node.Transport, node.Protocol, node.Addr)
209+
return fmt.Sprintf("transport: %s, protocol: %s, addr: %s, whitelist: %v, blacklist: %v", node.Transport, node.Protocol, node.Addr, node.Whitelist, node.Blacklist)
166210
}

node_test.go

+43
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
package gost
2+
3+
import (
4+
"testing"
5+
6+
"github.com/stretchr/testify/assert"
7+
)
8+
9+
func TestNodeDefaultWhitelist(t *testing.T) {
10+
assert := assert.New(t)
11+
12+
node, _ := ParseProxyNode("http2://localhost:8000")
13+
14+
assert.True(node.Can("connect", "google.pl:80"))
15+
assert.True(node.Can("connect", "google.pl:443"))
16+
assert.True(node.Can("connect", "google.pl:22"))
17+
assert.True(node.Can("bind", "google.pl:80"))
18+
assert.True(node.Can("bind", "google.com:80"))
19+
}
20+
21+
func TestNodeWhitelist(t *testing.T) {
22+
assert := assert.New(t)
23+
24+
node, _ := ParseProxyNode("http2://localhost:8000?whitelist=connect:google.pl:80,443")
25+
26+
assert.True(node.Can("connect", "google.pl:80"))
27+
assert.True(node.Can("connect", "google.pl:443"))
28+
assert.False(node.Can("connect", "google.pl:22"))
29+
assert.False(node.Can("bind", "google.pl:80"))
30+
assert.False(node.Can("bind", "google.com:80"))
31+
}
32+
33+
func TestNodeBlacklist(t *testing.T) {
34+
assert := assert.New(t)
35+
36+
node, _ := ParseProxyNode("http2://localhost:8000?blacklist=connect:google.pl:80,443")
37+
38+
assert.False(node.Can("connect", "google.pl:80"))
39+
assert.False(node.Can("connect", "google.pl:443"))
40+
assert.True(node.Can("connect", "google.pl:22"))
41+
assert.True(node.Can("bind", "google.pl:80"))
42+
assert.True(node.Can("bind", "google.com:80"))
43+
}

permissions.go

+185
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,185 @@
1+
package gost
2+
3+
import (
4+
"errors"
5+
"fmt"
6+
"strconv"
7+
"strings"
8+
9+
glob "github.com/ryanuber/go-glob"
10+
)
11+
12+
type PortRange struct {
13+
Min, Max int
14+
}
15+
16+
type PortSet []PortRange
17+
18+
type StringSet []string
19+
20+
type Permission struct {
21+
Actions StringSet
22+
Hosts StringSet
23+
Ports PortSet
24+
}
25+
26+
type Permissions []Permission
27+
28+
func minint(x, y int) int {
29+
if x < y {
30+
return x
31+
}
32+
return y
33+
}
34+
35+
func maxint(x, y int) int {
36+
if x > y {
37+
return x
38+
}
39+
return y
40+
}
41+
42+
func (ir *PortRange) Contains(value int) bool {
43+
return value >= ir.Min && value <= ir.Max
44+
}
45+
46+
func ParsePortRange(s string) (*PortRange, error) {
47+
if s == "*" {
48+
return &PortRange{Min: 0, Max: 65535}, nil
49+
}
50+
51+
minmax := strings.Split(s, "-")
52+
switch len(minmax) {
53+
case 1:
54+
port, err := strconv.Atoi(s)
55+
if err != nil {
56+
return nil, err
57+
}
58+
if port < 0 || port > 65535 {
59+
return nil, fmt.Errorf("invalid port: %s", s)
60+
}
61+
return &PortRange{Min: port, Max: port}, nil
62+
case 2:
63+
min, err := strconv.Atoi(minmax[0])
64+
if err != nil {
65+
return nil, err
66+
}
67+
max, err := strconv.Atoi(minmax[1])
68+
if err != nil {
69+
return nil, err
70+
}
71+
72+
realmin := maxint(0, minint(min, max))
73+
realmax := minint(65535, maxint(min, max))
74+
75+
return &PortRange{Min: realmin, Max: realmax}, nil
76+
default:
77+
return nil, fmt.Errorf("invalid range: %s", s)
78+
}
79+
}
80+
81+
func (ps *PortSet) Contains(value int) bool {
82+
for _, portRange := range *ps {
83+
if portRange.Contains(value) {
84+
return true
85+
}
86+
}
87+
88+
return false
89+
}
90+
91+
func ParsePortSet(s string) (*PortSet, error) {
92+
ps := &PortSet{}
93+
94+
if s == "" {
95+
return nil, errors.New("must specify at least one port")
96+
}
97+
98+
ranges := strings.Split(s, ",")
99+
100+
for _, r := range ranges {
101+
portRange, err := ParsePortRange(r)
102+
103+
if err != nil {
104+
return nil, err
105+
}
106+
107+
*ps = append(*ps, *portRange)
108+
}
109+
110+
return ps, nil
111+
}
112+
113+
func (ss *StringSet) Contains(subj string) bool {
114+
for _, s := range *ss {
115+
if glob.Glob(s, subj) {
116+
return true
117+
}
118+
}
119+
120+
return false
121+
}
122+
123+
func ParseStringSet(s string) (*StringSet, error) {
124+
ss := &StringSet{}
125+
if s == "" {
126+
return nil, errors.New("cannot be empty")
127+
}
128+
129+
*ss = strings.Split(s, ",")
130+
131+
return ss, nil
132+
}
133+
134+
func (ps *Permissions) Can(action string, host string, port int) bool {
135+
for _, p := range *ps {
136+
if p.Actions.Contains(action) && p.Hosts.Contains(host) && p.Ports.Contains(port) {
137+
return true
138+
}
139+
}
140+
141+
return false
142+
}
143+
144+
func ParsePermissions(s string) (*Permissions, error) {
145+
ps := &Permissions{}
146+
147+
if s == "" {
148+
return &Permissions{}, nil
149+
}
150+
151+
perms := strings.Split(s, "+")
152+
153+
for _, perm := range perms {
154+
parts := strings.Split(perm, ":")
155+
156+
switch len(parts) {
157+
case 3:
158+
actions, err := ParseStringSet(parts[0])
159+
160+
if err != nil {
161+
return nil, fmt.Errorf("action list must look like connect,bind given: %s", parts[0])
162+
}
163+
164+
hosts, err := ParseStringSet(parts[1])
165+
166+
if err != nil {
167+
return nil, fmt.Errorf("hosts list must look like google.pl,*.google.com given: %s", parts[1])
168+
}
169+
170+
ports, err := ParsePortSet(parts[2])
171+
172+
if err != nil {
173+
return nil, fmt.Errorf("ports list must look like 80,8000-9000, given: %s", parts[2])
174+
}
175+
176+
permission := Permission{Actions: *actions, Hosts: *hosts, Ports: *ports}
177+
178+
*ps = append(*ps, permission)
179+
default:
180+
return nil, fmt.Errorf("permission must have format [actions]:[hosts]:[ports] given: %s", perm)
181+
}
182+
}
183+
184+
return ps, nil
185+
}

0 commit comments

Comments
 (0)