@@ -26,6 +26,7 @@ import (
2626 "github.com/spf13/afero"
2727 "go.uber.org/atomic"
2828 gossh "golang.org/x/crypto/ssh"
29+ "golang.org/x/exp/slices"
2930 "golang.org/x/xerrors"
3031
3132 "cdr.dev/slog"
@@ -42,14 +43,6 @@ const (
4243 // unlikely to shadow other exit codes, which are typically 1, 2, 3, etc.
4344 MagicSessionErrorCode = 229
4445
45- // MagicSessionTypeEnvironmentVariable is used to track the purpose behind an SSH connection.
46- // This is stripped from any commands being executed, and is counted towards connection stats.
47- MagicSessionTypeEnvironmentVariable = "CODER_SSH_SESSION_TYPE"
48- // MagicSessionTypeVSCode is set in the SSH config by the VS Code extension to identify itself.
49- MagicSessionTypeVSCode = "vscode"
50- // MagicSessionTypeJetBrains is set in the SSH config by the JetBrains
51- // extension to identify itself.
52- MagicSessionTypeJetBrains = "jetbrains"
5346 // MagicProcessCmdlineJetBrains is a string in a process's command line that
5447 // uniquely identifies it as JetBrains software.
5548 MagicProcessCmdlineJetBrains = "idea.vendor.name=JetBrains"
@@ -60,6 +53,29 @@ const (
6053 BlockedFileTransferErrorMessage = "File transfer has been disabled."
6154)
6255
56+ // MagicSessionType is a type that represents the type of session that is being
57+ // established.
58+ type MagicSessionType string
59+
60+ const (
61+ // MagicSessionTypeEnvironmentVariable is used to track the purpose behind an SSH connection.
62+ // This is stripped from any commands being executed, and is counted towards connection stats.
63+ MagicSessionTypeEnvironmentVariable = "CODER_SSH_SESSION_TYPE"
64+ )
65+
66+ // MagicSessionType enums.
67+ const (
68+ // MagicSessionTypeUnknown means the session type could not be determined.
69+ MagicSessionTypeUnknown MagicSessionType = "unknown"
70+ // MagicSessionTypeSSH is the default session type.
71+ MagicSessionTypeSSH MagicSessionType = "ssh"
72+ // MagicSessionTypeVSCode is set in the SSH config by the VS Code extension to identify itself.
73+ MagicSessionTypeVSCode MagicSessionType = "vscode"
74+ // MagicSessionTypeJetBrains is set in the SSH config by the JetBrains
75+ // extension to identify itself.
76+ MagicSessionTypeJetBrains MagicSessionType = "jetbrains"
77+ )
78+
6379// BlockedFileTransferCommands contains a list of restricted file transfer commands.
6480var BlockedFileTransferCommands = []string {"nc" , "rsync" , "scp" , "sftp" }
6581
@@ -255,14 +271,42 @@ func (s *Server) ConnStats() ConnStats {
255271 }
256272}
257273
274+ func extractMagicSessionType (env []string ) (magicType MagicSessionType , rawType string , filteredEnv []string ) {
275+ for _ , kv := range env {
276+ if ! strings .HasPrefix (kv , MagicSessionTypeEnvironmentVariable ) {
277+ continue
278+ }
279+
280+ rawType = strings .TrimPrefix (kv , MagicSessionTypeEnvironmentVariable + "=" )
281+ // Keep going, we'll use the last instance of the env.
282+ }
283+
284+ // Always force lowercase checking to be case-insensitive.
285+ switch MagicSessionType (strings .ToLower (rawType )) {
286+ case MagicSessionTypeVSCode :
287+ magicType = MagicSessionTypeVSCode
288+ case MagicSessionTypeJetBrains :
289+ magicType = MagicSessionTypeJetBrains
290+ case "" , MagicSessionTypeSSH :
291+ magicType = MagicSessionTypeSSH
292+ default :
293+ magicType = MagicSessionTypeUnknown
294+ }
295+
296+ return magicType , rawType , slices .DeleteFunc (env , func (kv string ) bool {
297+ return strings .HasPrefix (kv , MagicSessionTypeEnvironmentVariable + "=" )
298+ })
299+ }
300+
258301func (s * Server ) sessionHandler (session ssh.Session ) {
259302 ctx := session .Context ()
303+ id := uuid .New ()
260304 logger := s .logger .With (
261305 slog .F ("remote_addr" , session .RemoteAddr ()),
262306 slog .F ("local_addr" , session .LocalAddr ()),
263307 // Assigning a random uuid for each session is useful for tracking
264308 // logs for the same ssh session.
265- slog .F ("id" , uuid . NewString ()),
309+ slog .F ("id" , id . String ()),
266310 )
267311 logger .Info (ctx , "handling ssh session" )
268312
@@ -274,16 +318,21 @@ func (s *Server) sessionHandler(session ssh.Session) {
274318 }
275319 defer s .trackSession (session , false )
276320
277- extraEnv := make ([]string , 0 )
278- x11 , hasX11 := session .X11 ()
279- if hasX11 {
280- display , handled := s .x11Handler (session .Context (), x11 )
281- if ! handled {
282- _ = session .Exit (1 )
283- logger .Error (ctx , "x11 handler failed" )
284- return
285- }
286- extraEnv = append (extraEnv , fmt .Sprintf ("DISPLAY=localhost:%d.%d" , display , x11 .ScreenNumber ))
321+ env := session .Environ ()
322+ magicType , magicTypeRaw , env := extractMagicSessionType (env )
323+
324+ switch magicType {
325+ case MagicSessionTypeVSCode :
326+ s .connCountVSCode .Add (1 )
327+ defer s .connCountVSCode .Add (- 1 )
328+ case MagicSessionTypeJetBrains :
329+ // Do nothing here because JetBrains launches hundreds of ssh sessions.
330+ // We instead track JetBrains in the single persistent tcp forwarding channel.
331+ case MagicSessionTypeSSH :
332+ s .connCountSSHSession .Add (1 )
333+ defer s .connCountSSHSession .Add (- 1 )
334+ case MagicSessionTypeUnknown :
335+ logger .Warn (ctx , "invalid magic ssh session type specified" , slog .F ("raw_type" , magicTypeRaw ))
287336 }
288337
289338 if s .fileTransferBlocked (session ) {
@@ -309,7 +358,18 @@ func (s *Server) sessionHandler(session ssh.Session) {
309358 return
310359 }
311360
312- err := s .sessionStart (logger , session , extraEnv )
361+ x11 , hasX11 := session .X11 ()
362+ if hasX11 {
363+ display , handled := s .x11Handler (session .Context (), x11 )
364+ if ! handled {
365+ _ = session .Exit (1 )
366+ logger .Error (ctx , "x11 handler failed" )
367+ return
368+ }
369+ env = append (env , fmt .Sprintf ("DISPLAY=localhost:%d.%d" , display , x11 .ScreenNumber ))
370+ }
371+
372+ err := s .sessionStart (logger , session , env , magicType )
313373 var exitError * exec.ExitError
314374 if xerrors .As (err , & exitError ) {
315375 code := exitError .ExitCode ()
@@ -379,32 +439,8 @@ func (s *Server) fileTransferBlocked(session ssh.Session) bool {
379439 return false
380440}
381441
382- func (s * Server ) sessionStart (logger slog.Logger , session ssh.Session , extraEnv []string ) (retErr error ) {
442+ func (s * Server ) sessionStart (logger slog.Logger , session ssh.Session , env []string , magicType MagicSessionType ) (retErr error ) {
383443 ctx := session .Context ()
384- env := append (session .Environ (), extraEnv ... )
385- var magicType string
386- for index , kv := range env {
387- if ! strings .HasPrefix (kv , MagicSessionTypeEnvironmentVariable ) {
388- continue
389- }
390- magicType = strings .ToLower (strings .TrimPrefix (kv , MagicSessionTypeEnvironmentVariable + "=" ))
391- env = append (env [:index ], env [index + 1 :]... )
392- }
393-
394- // Always force lowercase checking to be case-insensitive.
395- switch magicType {
396- case MagicSessionTypeVSCode :
397- s .connCountVSCode .Add (1 )
398- defer s .connCountVSCode .Add (- 1 )
399- case MagicSessionTypeJetBrains :
400- // Do nothing here because JetBrains launches hundreds of ssh sessions.
401- // We instead track JetBrains in the single persistent tcp forwarding channel.
402- case "" :
403- s .connCountSSHSession .Add (1 )
404- defer s .connCountSSHSession .Add (- 1 )
405- default :
406- logger .Warn (ctx , "invalid magic ssh session type specified" , slog .F ("type" , magicType ))
407- }
408444
409445 magicTypeLabel := magicTypeMetricLabel (magicType )
410446 sshPty , windowSize , isPty := session .Pty ()
@@ -473,7 +509,7 @@ func (s *Server) startNonPTYSession(logger slog.Logger, session ssh.Session, mag
473509 }()
474510 go func () {
475511 for sig := range sigs {
476- s . handleSignal (logger , sig , cmd .Process , magicTypeLabel )
512+ handleSignal (logger , sig , cmd .Process , s . metrics , magicTypeLabel )
477513 }
478514 }()
479515 return cmd .Wait ()
@@ -558,7 +594,7 @@ func (s *Server) startPTYSession(logger slog.Logger, session ptySession, magicTy
558594 sigs = nil
559595 continue
560596 }
561- s . handleSignal (logger , sig , process , magicTypeLabel )
597+ handleSignal (logger , sig , process , s . metrics , magicTypeLabel )
562598 case win , ok := <- windowSize :
563599 if ! ok {
564600 windowSize = nil
@@ -612,15 +648,15 @@ func (s *Server) startPTYSession(logger slog.Logger, session ptySession, magicTy
612648 return nil
613649}
614650
615- func ( s * Server ) handleSignal (logger slog.Logger , ssig ssh.Signal , signaler interface { Signal (os.Signal ) error }, magicTypeLabel string ) {
651+ func handleSignal (logger slog.Logger , ssig ssh.Signal , signaler interface { Signal (os.Signal ) error }, metrics * sshServerMetrics , magicTypeLabel string ) {
616652 ctx := context .Background ()
617653 sig := osSignalFrom (ssig )
618654 logger = logger .With (slog .F ("ssh_signal" , ssig ), slog .F ("signal" , sig .String ()))
619655 logger .Info (ctx , "received signal from client" )
620656 err := signaler .Signal (sig )
621657 if err != nil {
622658 logger .Warn (ctx , "signaling the process failed" , slog .Error (err ))
623- s . metrics .sessionErrors .WithLabelValues (magicTypeLabel , "yes" , "signal" ).Add (1 )
659+ metrics .sessionErrors .WithLabelValues (magicTypeLabel , "yes" , "signal" ).Add (1 )
624660 }
625661}
626662
0 commit comments