Skip to content

Commit

Permalink
Use callback to create net.Listener (#1)
Browse files Browse the repository at this point in the history
Co-authored-by: Dean Sheather <[email protected]>
  • Loading branch information
samchouse and deansheather committed Sep 5, 2024
1 parent 9e6b773 commit f86b780
Show file tree
Hide file tree
Showing 3 changed files with 44 additions and 26 deletions.
9 changes: 7 additions & 2 deletions ssh.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package ssh

import (
"crypto/subtle"
"errors"
"net"

gossh "golang.org/x/crypto/ssh"
Expand Down Expand Up @@ -29,6 +30,9 @@ const (
// DefaultHandler is the default Handler used by Serve.
var DefaultHandler Handler

// ErrReject is returned by some callbacks to reject a request.
var ErrRejected = errors.New("rejected")

// Option is a functional option handler for Server.
type Option func(*Server) error

Expand Down Expand Up @@ -69,8 +73,9 @@ type ReversePortForwardingCallback func(ctx Context, bindHost string, bindPort u
type LocalUnixForwardingCallback func(ctx Context, socketPath string) bool

// ReverseUnixForwardingCallback is a hook for allowing reverse unix forwarding
// ([email protected]).
type ReverseUnixForwardingCallback func(ctx Context, socketPath string) bool
// ([email protected]). Returning ErrRejected will reject the
// request.
type ReverseUnixForwardingCallback func(ctx Context, socketPath string) (net.Listener, error)

// ServerConfigCallback is a hook for creating custom default server configs
type ServerConfigCallback func(ctx Context) *gossh.ServerConfig
Expand Down
53 changes: 33 additions & 20 deletions streamlocal.go
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ func (h *ForwardedUnixHandler) HandleSSHRequest(ctx Context, srv *Server, req *g
return false, nil
}

if srv.ReverseUnixForwardingCallback == nil || !srv.ReverseUnixForwardingCallback(ctx, reqPayload.SocketPath) {
if srv.ReverseUnixForwardingCallback == nil {
return false, []byte("unix forwarding is disabled")
}

Expand All @@ -123,26 +123,11 @@ func (h *ForwardedUnixHandler) HandleSSHRequest(ctx Context, srv *Server, req *g
return false, nil
}

// Create socket parent dir if not exists.
parentDir := filepath.Dir(addr)
err = os.MkdirAll(parentDir, 0700)
if err != nil {
// TODO: log mkdir failure
return false, nil
}

// Remove existing socket if it exists. We do not use os.Remove() here
// so that directories are kept. Note that it's possible that we will
// overwrite a regular file here. Both of these behaviors match OpenSSH,
// however, which is why we unlink.
err = unlink(addr)
if err != nil && !errors.Is(err, fs.ErrNotExist) {
// TODO: log unlink failure
return false, nil
}

ln, err := net.Listen("unix", addr)
ln, err := srv.ReverseUnixForwardingCallback(ctx, addr)
if err != nil {
if errors.Is(err, ErrRejected) {
return false, []byte("unix forwarding is disabled")
}
// TODO: log unix listen failure
return false, nil
}
Expand Down Expand Up @@ -227,3 +212,31 @@ func unlink(path string) error {
}
}
}

// SimpleUnixReverseForwardingCallback provides a basic implementation for
// ReverseUnixForwardingCallback. The parent directory will be created (with
// os.MkdirAll), and existing files with the same name will be removed.
func SimpleUnixReverseForwardingCallback(_ Context, socketPath string) (net.Listener, error) {
// Create socket parent dir if not exists.
parentDir := filepath.Dir(socketPath)
err := os.MkdirAll(parentDir, 0700)
if err != nil {
return nil, fmt.Errorf("failed to create parent directory %q for socket %q: %w", parentDir, socketPath, err)
}

// Remove existing socket if it exists. We do not use os.Remove() here
// so that directories are kept. Note that it's possible that we will
// overwrite a regular file here. Both of these behaviors match OpenSSH,
// however, which is why we unlink.
err = unlink(socketPath)
if err != nil && !errors.Is(err, fs.ErrNotExist) {
return nil, fmt.Errorf("failed to remove existing file in socket path %q: %w", socketPath, err)
}

ln, err := net.Listen("unix", socketPath)
if err != nil {
return nil, fmt.Errorf("failed to listen on unix socket %q: %w", socketPath, err)
}

return ln, err
}
8 changes: 4 additions & 4 deletions streamlocal_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -127,11 +127,11 @@ func TestReverseUnixForwardingWorks(t *testing.T) {

_, client, cleanup := newTestSession(t, &Server{
Handler: func(s Session) {},
ReverseUnixForwardingCallback: func(ctx Context, socketPath string) bool {
ReverseUnixForwardingCallback: func(ctx Context, socketPath string) (net.Listener, error) {
if socketPath != remoteSocketPath {
panic("unexpected socket path: " + socketPath)
}
return true
return SimpleUnixReverseForwardingCallback(ctx, socketPath)
},
}, nil)
defer cleanup()
Expand Down Expand Up @@ -182,12 +182,12 @@ func TestReverseUnixForwardingRespectsCallback(t *testing.T) {
var called int64
_, client, cleanup := newTestSession(t, &Server{
Handler: func(s Session) {},
ReverseUnixForwardingCallback: func(ctx Context, socketPath string) bool {
ReverseUnixForwardingCallback: func(ctx Context, socketPath string) (net.Listener, error) {
atomic.AddInt64(&called, 1)
if socketPath != remoteSocketPath {
panic("unexpected socket path: " + socketPath)
}
return false
return nil, ErrRejected
},
}, nil)
defer cleanup()
Expand Down

0 comments on commit f86b780

Please sign in to comment.