The very first message sent in a TLS connection is the Client Hello record, in which the client greets the server and tells it, among other things, the server name it wants to connect to. This is called Server Name Indication, or SNI for short, and it’s quite handy as it allows many different servers to be co-located on a single IP address.

The server name is sent in plaintext, which is unfortunately really bad for privacy and censorship resistance, but does enable something very useful: a proxy server can read the server name and use it to decide where to route the connection, without having to decrypt the connection. You can leverage this to make many different physical servers accessible from the Internet even if you have only one public IPv4 address: the proxy listens on your public IP address and forwards connections to the appropriate private IP address based on the SNI.

I just finished writing such a proxy server, which I plan to run on my home network’s router so that I can easily access my internal servers from anywhere on the Internet, without a VPN or SSH port forwarding. I was pleased by how easy it was to write this proxy server using only Go’s standard library. It’s a great example of how well-suited Go is for programs involving networking and cryptography.

Let’s start with a standard listen/accept loop (right out of the examples for Go’s net package):

func main() {
        l, err := net.Listen("tcp", ":443")
        if err != nil {
                log.Fatal(err)
        }
        for {
                conn, err := l.Accept()
                if err != nil {
                        log.Print(err)
                        continue
                }
                go handleConnection(conn)
        }
}

Here’s a sketch of the handleConnection function, which reads the Client Hello record from the client, dials the backend server indicated by the Client Hello, and then proxies the client to and from the backend. (Note that we dial the backend using the SNI value, which works well with split-horizon DNS where the proxy sees the backend’s private IP address and external clients see the proxy’s public IP address. If that doesn’t work for you, you can use more complicated routing logic.)

func handleConnection(clientConn net.Conn) {
        defer clientConn.Close()

        // ... read Client Hello from clientConn ...

        backendConn, err := net.Dial("tcp", net.JoinHostPort(clientHello.ServerName, "443"))
        if err != nil {
                log.Print(err)
                return
        }
        defer backendConn.Close()

        // ... proxy clientConn <==> backendConn ...
}
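
If dialing the SNI value directly doesn’t suit your network, one option for the “more complicated routing logic” is a static table that maps server names to backend addresses. The following is a minimal sketch and not part of the proxy itself: the routes map, the backendAddr function, and the addresses in the table are hypothetical names chosen for illustration.

```go
package main

import "strings"

// routes maps an SNI server name to the backend address to dial.
// The names and addresses here are made up for illustration.
var routes = map[string]string{
	"wiki.internal.example.com": "192.168.1.10:443",
	"git.internal.example.com":  "192.168.1.11:8443",
}

// backendAddr returns the address to dial for serverName, or false if the
// name is not in the table. Lookups are case-insensitive, as DNS names are.
func backendAddr(serverName string) (string, bool) {
	addr, ok := routes[strings.ToLower(serverName)]
	return addr, ok
}
```

With something like this, handleConnection would dial the returned address instead of net.JoinHostPort(clientHello.ServerName, "443") and close the connection when the name is unknown, so the table also acts as an allowlist.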

Let’s assume for now we have a convenient function to read a Client Hello record from an io.Reader and return a tls.ClientHelloInfo:

func readClientHello(reader io.Reader) (*tls.ClientHelloInfo, error)

We can’t simply call this function from handleConnection, because once the Client Hello is read, the bytes are gone. We need to preserve the bytes and forward them along to the backend, which is expecting a proper TLS connection that starts with a Client Hello record.

What we need to do instead is “peek” at the Client Hello record, and thanks to some simple but powerful abstractions from Go’s io package, this can be done with just six lines of code:

func peekClientHello(reader io.Reader) (*tls.ClientHelloInfo, io.Reader, error) {
        peekedBytes := new(bytes.Buffer)
        hello, err := readClientHello(io.TeeReader(reader, peekedBytes))
        if err != nil {
                return nil, nil, err
        }
        return hello, io.MultiReader(peekedBytes, reader), nil
}

What this code does is create a TeeReader, which is a reader that wraps another reader and writes everything that is read to a writer, which in our case is a byte buffer. We pass the TeeReader to readClientHello, so every byte read by readClientHello gets saved to our buffer. Finally, we create a MultiReader which essentially concatenates our buffer with the original reader. Reads from the MultiReader initially come out of the buffer, and when that’s exhausted, continue from the original reader. We return the MultiReader to the caller along with the ClientHelloInfo. When the caller reads from the MultiReader it will see a full TLS connection stream, starting with the Client Hello.

Now we just need to implement readClientHello. We could open up the TLS RFCs and learn how to parse a Client Hello record, but it turns out we can let crypto/tls do the work for us, thanks to a callback function in tls.Config called GetConfigForClient:

// GetConfigForClient, if not nil, is called after a ClientHello is
// received from a client.
GetConfigForClient func(*ClientHelloInfo) (*Config, error) // Go 1.8

Roughly, what we need to do is create a TLS server-side connection with a GetConfigForClient callback that saves the ClientHelloInfo passed to it. However, creating a TLS connection requires a full-blown net.Conn, and readClientHello is passed merely an io.Reader. So let’s create a type, readOnlyConn, which wraps an io.Reader and satisfies the net.Conn interface:

type readOnlyConn struct {
        reader io.Reader
}
func (conn readOnlyConn) Read(p []byte) (int, error)         { return conn.reader.Read(p) }
func (conn readOnlyConn) Write(p []byte) (int, error)        { return 0, io.ErrClosedPipe }
func (conn readOnlyConn) Close() error                       { return nil }
func (conn readOnlyConn) LocalAddr() net.Addr                { return nil }
func (conn readOnlyConn) RemoteAddr() net.Addr               { return nil }
func (conn readOnlyConn) SetDeadline(t time.Time) error      { return nil }
func (conn readOnlyConn) SetReadDeadline(t time.Time) error  { return nil }
func (conn readOnlyConn) SetWriteDeadline(t time.Time) error { return nil }

readOnlyConn forwards reads to the reader and simulates a broken pipe when written to (as if the client closed the connection before the server could reply). All other operations are a no-op.

Now we’re ready to write readClientHello:

func readClientHello(reader io.Reader) (*tls.ClientHelloInfo, error) {
        var hello *tls.ClientHelloInfo

        err := tls.Server(readOnlyConn{reader: reader}, &tls.Config{
                GetConfigForClient: func(argHello *tls.ClientHelloInfo) (*tls.Config, error) {
                        hello = new(tls.ClientHelloInfo)
                        *hello = *argHello
                        return nil, nil
                },
        }).Handshake()

        if hello == nil {
                return nil, err
        }

        return hello, nil
}

Note that Handshake always fails because the readOnlyConn is not a real connection. As long as the Client Hello is successfully read, the failure should only happen after GetConfigForClient is called, so we only care about the error if hello was never set.
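
One way to convince yourself that this works is to feed it a real Client Hello without opening any sockets. The sketch below is an assumption-laden test harness rather than part of the proxy: sniffSNI is a hypothetical helper that runs a crypto/tls client over an in-memory net.Pipe, and the three definitions from this post are repeated so that the file compiles on its own.

```go
package main

import (
	"bytes"
	"crypto/tls"
	"io"
	"net"
	"time"
)

// readOnlyConn, readClientHello, and peekClientHello are repeated from the
// post so that this file is self-contained.
type readOnlyConn struct{ reader io.Reader }

func (conn readOnlyConn) Read(p []byte) (int, error)         { return conn.reader.Read(p) }
func (conn readOnlyConn) Write(p []byte) (int, error)        { return 0, io.ErrClosedPipe }
func (conn readOnlyConn) Close() error                       { return nil }
func (conn readOnlyConn) LocalAddr() net.Addr                { return nil }
func (conn readOnlyConn) RemoteAddr() net.Addr               { return nil }
func (conn readOnlyConn) SetDeadline(t time.Time) error      { return nil }
func (conn readOnlyConn) SetReadDeadline(t time.Time) error  { return nil }
func (conn readOnlyConn) SetWriteDeadline(t time.Time) error { return nil }

func readClientHello(reader io.Reader) (*tls.ClientHelloInfo, error) {
	var hello *tls.ClientHelloInfo
	err := tls.Server(readOnlyConn{reader: reader}, &tls.Config{
		GetConfigForClient: func(argHello *tls.ClientHelloInfo) (*tls.Config, error) {
			hello = new(tls.ClientHelloInfo)
			*hello = *argHello
			return nil, nil
		},
	}).Handshake()
	if hello == nil {
		return nil, err
	}
	return hello, nil
}

func peekClientHello(reader io.Reader) (*tls.ClientHelloInfo, io.Reader, error) {
	peekedBytes := new(bytes.Buffer)
	hello, err := readClientHello(io.TeeReader(reader, peekedBytes))
	if err != nil {
		return nil, nil, err
	}
	return hello, io.MultiReader(peekedBytes, reader), nil
}

// sniffSNI is a hypothetical test helper. It starts a TLS client that asks
// for serverName over an in-memory pipe, peeks at its Client Hello, and
// returns the SNI value plus the first byte a backend would receive.
func sniffSNI(serverName string) (string, byte, error) {
	clientSide, proxySide := net.Pipe()
	defer proxySide.Close() // also unblocks the client goroutine below

	go func() {
		defer clientSide.Close()
		// The handshake can never complete because nothing answers it;
		// only the Client Hello that the client writes matters here.
		tls.Client(clientSide, &tls.Config{ServerName: serverName}).Handshake()
	}()

	hello, replay, err := peekClientHello(proxySide)
	if err != nil {
		return "", 0, err
	}
	first := make([]byte, 1)
	if _, err := io.ReadFull(replay, first); err != nil {
		return "", 0, err
	}
	return hello.ServerName, first[0], nil
}
```

Calling sniffSNI("wiki.internal.example.com") should return that same name, and the replayed stream should begin with byte 0x16, the TLS record type for handshake messages.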

Let’s put everything together to write the full handleConnection function. I’ve added deadlines (thanks, Filippo!) and a check that the SNI value ends with .internal.example.com to prevent this from being used as an open proxy. When I deploy this, I will use the DNS suffix of my home network.

func handleConnection(clientConn net.Conn) {
        defer clientConn.Close()

        if err := clientConn.SetReadDeadline(time.Now().Add(5 * time.Second)); err != nil {
                log.Print(err)
                return
        }

        clientHello, clientReader, err := peekClientHello(clientConn)
        if err != nil {
                log.Print(err)
                return
        }

        if err := clientConn.SetReadDeadline(time.Time{}); err != nil {
                log.Print(err)
                return
        }

        if !strings.HasSuffix(clientHello.ServerName, ".internal.example.com") {
                log.Print("Blocking connection to unauthorized backend")
                return
        }

        backendConn, err := net.DialTimeout("tcp", net.JoinHostPort(clientHello.ServerName, "443"), 5*time.Second)
        if err != nil {
                log.Print(err)
                return
        }
        defer backendConn.Close()

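        // Copy in both directions. Each goroutine half-closes its destination
        // once its source reaches EOF, so the peer sees the end of the stream.
        var wg sync.WaitGroup
        wg.Add(2)

        go func() {
                io.Copy(clientConn, backendConn)
                clientConn.(*net.TCPConn).CloseWrite()
                wg.Done()
        }()
        go func() {
                // clientReader replays the peeked Client Hello first.
                io.Copy(backendConn, clientReader)
                backendConn.(*net.TCPConn).CloseWrite()
                wg.Done()
        }()

        wg.Wait()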
}
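
The suffix check is what keeps this from being an open proxy, so it may be worth isolating in a small function that can be tested on its own. The sketch below is one hypothetical way to do that: allowedBackend is not part of the original source, and the lowercasing is an added assumption that a client might send a mixed-case name.

```go
package main

import "strings"

// allowedBackend reports whether serverName falls under suffix, which is
// expected to be lowercase and to start with a dot, for example
// ".internal.example.com". The leading dot keeps a name such as
// "evilinternal.example.com" from matching.
func allowedBackend(serverName, suffix string) bool {
	return strings.HasSuffix(strings.ToLower(serverName), suffix)
}
```

A connection with no SNI at all has an empty ServerName, which never matches a non-empty suffix and is therefore rejected as well.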

Here’s the complete Go source code – just 115 lines! (Not counting copyright legalese.)