8000 feat: add port-forward subcommand by deansheather · Pull Request #1350 · coder/coder · GitHub
[go: up one dir, main page]

Skip to content

feat: add port-forward subcommand #1350

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 12 commits into from
May 18, 2022
Prev Previous commit
Next Next commit
chore: fix lint errors
  • Loading branch information
deansheather committed May 17, 2022
commit 6dfd2f689e9e073375037be404175b2f925555de
20 changes: 15 additions & 5 deletions agent/agent.go
Original file line number Diff line number Diff line change
Expand Up @@ -735,20 +735,30 @@ func Bicopy(ctx context.Context, c1, c2 io.ReadWriteCloser) {
defer c1.Close()
defer c2.Close()

ctx, cancel := context.WithCancel(ctx)

var wg sync.WaitGroup
copyFunc := func(dst io.WriteCloser, src io.Reader) {
defer cancel()
defer wg.Done()
_, _ = io.Copy(dst, src)
}

wg.Add(2)
go copyFunc(c1, c2)
go copyFunc(c2, c1)

<-ctx.Done()
// Convert waitgroup to a channel so we can also wait on the context.
done := make(chan struct{})
go func() {
defer close(done)
wg.Wait()
}()

select {
case <-ctx.Done():
case <-done:
}
}

// ExpandPath expands the tilde at the beggining of a path to the current user's
// ExpandPath expands the tilde at the beginning of a path to the current user's
// home directory and returns a full absolute path.
func ExpandPath(in string) (string, error) {
usr, err := user.Current()
Expand Down
35 changes: 17 additions & 18 deletions cli/portforward.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@ import (
"github.com/spf13/cobra"
"golang.org/x/xerrors"

"github.com/coder/coder/agent"
coderagent "github.com/coder/coder/agent"
"github.com/coder/coder/cli/cliui"
"github.com/coder/coder/coderd/database"
Expand Down Expand Up @@ -142,15 +141,15 @@ func portForward() *cobra.Command {
case <-ctx.Done():
closeErr = ctx.Err()
case <-sigs:
fmt.Fprintln(cmd.OutOrStderr(), "Received signal, closing all listeners and active connections")
_, _ = fmt.Fprintln(cmd.OutOrStderr(), "Received signal, closing all listeners and active connections")
closeErr = xerrors.New("signal received")
}

cancel()
closeAllListeners()
}()

fmt.Fprintln(cmd.OutOrStderr(), "Ready!")
_, _ = fmt.Fprintln(cmd.OutOrStderr(), "Ready!")
wg.Wait()
return closeErr
},
Expand All @@ -163,8 +162,8 @@ func portForward() *cobra.Command {
return cmd
}

func listenAndPortForward(ctx context.Context, cmd *cobra.Command, conn *agent.Conn, wg *sync.WaitGroup, spec portForwardSpec) (net.Listener, error) {
fmt.Fprintf(cmd.OutOrStderr(), "Forwarding '%v://%v' locally to '%v://%v' in the workspace\n", spec.listenNetwork, spec.listenAddress, spec.dialNetwork, spec.dialAddress)
func listenAndPortForward(ctx context.Context, cmd *cobra.Command, conn *coderagent.Conn, wg *sync.WaitGroup, spec portForwardSpec) (net.Listener, error) {
_, _ = fmt.Fprintf(cmd.OutOrStderr(), "Forwarding '%v://%v' locally to '%v://%v' in the workspace\n", spec.listenNetwork, spec.listenAddress, spec.dialNetwork, spec.dialAddress)

var (
l net.Listener
Expand All @@ -183,7 +182,7 @@ func listenAndPortForward(ctx context.Context, cmd *cobra.Command, conn *agent.C
return nil, xerrors.Errorf("parse port %v from %q as int: %w", port, spec.listenAddress, err)
}

l, err = udp.Listen(spec.listenNetwork, &net.UDPAddr{
l, err = udp.Listen(spec.listenNetwork, &net.UDPAddr{ //nolint:ineffassign
IP: net.ParseIP(host),
Port: portInt,
})
Expand All @@ -202,16 +201,16 @@ func listenAndPortForward(ctx context.Context, cmd *cobra.Command, conn *agent.C
for {
netConn, err := l.Accept()
if err != nil {
fmt.Fprintf(cmd.OutOrStderr(), "Error accepting connection from '%v://%v': %+v\n", spec.listenNetwork, spec.listenAddress, err)
fmt.Fprintln(cmd.OutOrStderr(), "Killing listener")
_, _ = fmt.Fprintf(cmd.OutOrStderr(), "Error accepting connection from '%v://%v': %+v\n", spec.listenNetwork, spec.listenAddress, err)
_, _ = fmt.Fprintln(cmd.OutOrStderr(), "Killing listener")
return
}

go func(netConn net.Conn) {
defer netConn.Close()
remoteConn, err := conn.DialContext(ctx, spec.dialNetwork, spec.dialAddress)
if err != nil {
fmt.Fprintf(cmd.OutOrStderr(), "Failed to dial '%v://%v' in workspace: %s\n", spec.dialNetwork, spec.dialAddress, err)
_, _ = fmt.Fprintf(cmd.OutOrStderr(), "Failed to dial '%v://%v' in workspace: %s\n", spec.dialNetwork, spec.dialAddress, err)
return
}
defer remoteConn.Close()
Expand All @@ -232,10 +231,10 @@ type portForwardSpec struct {
dialAddress string // <ip>:<port> or path
}

func parsePortForwards(tcp, udp, unix []string) ([]portForwardSpec, error) {
func parsePortForwards(tcpSpecs, udpSpecs, unixSpecs []string) ([]portForwardSpec, error) {
specs := []portForwardSpec{}

for _, spec := range tcp {
for _, spec := range tcpSpecs {
local, remote, err := parsePortPort(spec)
if err != nil {
return nil, xerrors.Errorf("failed to parse TCP port-forward specification %q: %w", spec, err)
Expand All @@ -249,7 +248,7 @@ func parsePortForwards(tcp, udp, unix []string) ([]portForwardSpec, error) {
})
}

for _, spec := range udp {
for _, spec := range udpSpecs {
local, remote, err := parsePortPort(spec)
if err != nil {
return nil, xerrors.Errorf("failed to parse UDP port-forward specification %q: %w", spec, err)
Expand All @@ -263,7 +262,7 @@ func parsePortForwards(tcp, udp, unix []string) ([]portForwardSpec, error) {
})
}

for _, specStr := range unix {
for _, specStr := range unixSpecs {
localPath, localTCP, remotePath, err := parseUnixUnix(specStr)
if err != nil {
return nil, xerrors.Errorf("failed to parse Unix port-forward specification %q: %w", specStr, err)
Expand Down Expand Up @@ -312,15 +311,15 @@ func parsePort(in string) (uint16, error) {
}

func parseUnixPath(in string) (string, error) {
path, err := agent.ExpandPath(strings.TrimSpace(in))
path, err := coderagent.ExpandPath(strings.TrimSpace(in))
if err != nil {
return "", xerrors.Errorf("tidy path %q: %w", in, err)
}

return path, nil
}

func parsePortPort(in string) (uint16, uint16, error) {
func parsePortPort(in string) (local uint16, remote uint16, err error) {
parts := strings.Split(in, ":")
if len(parts) > 2 {
return 0, 0, xerrors.Errorf("invalid port specification %q", in)
Expand All @@ -330,16 +329,16 @@ func parsePortPort(in string) (uint16, uint16, error) {
parts = append(parts, parts[0])
}

local, err := parsePort(parts[0])
local, err = parsePort(parts[0])
if err != nil {
return 0, 0, xerrors.Errorf("parse local port from %q: %w", in, err)
}
remote, err := parsePort(parts[1])
remote, err = parsePort(parts[1])
if err != nil {
return 0, 0, xerrors.Errorf("parse remote port from %q: %w", in, err)
}

return uint16(local), uint16(remote), nil
return local, remote, nil
}

func parsePortOrUnixPath(in string) (string, uint16, error) {
Expand Down
8 changes: 6 additions & 2 deletions cli/portforward_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@ func TestPortForward(t *testing.T) {
t.Parallel()

t.Run("None", func(t *testing.T) {
t.Parallel()

client := coderdtest.New(t, nil)
_ = coderdtest.CreateFirstUser(t, client)

Expand Down Expand Up @@ -138,7 +140,7 @@ func TestPortForward(t *testing.T) {
},
}

for _, c := range cases {
for _, c := range cases { //nolint:paralleltest // the `c := c` confuses the linter
c := c
t.Run(c.name, func(t *testing.T) {
t.Parallel()
Expand Down Expand Up @@ -346,7 +348,9 @@ func TestPortForward(t *testing.T) {
for i, a := range dials {
c, err := d.DialContext(ctx, a.network, a.addr)
require.NoErrorf(t, err, "open connection %v to 'local' listener %v", i+1, i+1)
defer c.Close()
t.Cleanup(func() {
_ = c.Close()
})
conns[i] = c
}

Expand Down
18 changes: 12 additions & 6 deletions cli/ssh.go
4E8F
Original file line number Diff line number Diff line change
Expand Up @@ -166,29 +166,35 @@ func getWorkspaceAndAgent(ctx context.Context, client *codersdk.Client, orgID uu
return codersdk.Workspace{}, codersdk.WorkspaceAgent{}, xerrors.Errorf("workspace %q has no agents", workspace.Name)
}

var agent *codersdk.WorkspaceAgent
var (
// We can't use a pointer because linters are mad about using pointers
// from loop variables
agent codersdk.WorkspaceAgent
agentOK bool
)
if len(workspaceParts) >= 2 {
for _, otherAgent := range agents {
if otherAgent.Name != workspaceParts[1] {
continue
}
agent = &otherAgent
agent = otherAgent
agentOK = true
break
}

if agent == nil {
if !agentOK {
return codersdk.Workspace{}, codersdk.WorkspaceAgent{}, xerrors.Errorf("agent not found by name %q", workspaceParts[1])
}
}

if agent == nil {
if !agentOK {
if len(agents) > 1 {
return codersdk.Workspace{}, codersdk.WorkspaceAgent{}, xerrors.New("you must specify the name of an agent")
}
agent = &agents[0]
agent = agents[0]
}

return workspace, *agent, nil
return workspace, agent, nil
}

type stdioConn struct {
Expand Down
0