Skip to content

Commit

Permalink
Merge pull request #3641 from telepresenceio/thallgren/connect-extens…
Browse files Browse the repository at this point in the history
…ion-point

Add extension point for the CLI connection request.
  • Loading branch information
thallgren authored Jul 4, 2024
2 parents a185219 + 0754225 commit cd4e4e5
Show file tree
Hide file tree
Showing 12 changed files with 148 additions and 77 deletions.
14 changes: 7 additions & 7 deletions pkg/client/cli/cmd/status.go
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,7 @@ func run(cmd *cobra.Command, _ []string) error {
return err
}
sis[i], err = getStatusInfo(udCtx, info)
_ = daemon.GetUserClient(udCtx).Conn.Close()
_ = daemon.GetUserClient(udCtx).Close()
if err != nil {
return err
}
Expand Down Expand Up @@ -200,7 +200,7 @@ var GetStatusInfo = func(ctx context.Context) (ioutil.WriterTos, error) {
// GetTrafficManagerStatusExtras may return an extended struct
//
//nolint:gochecknoglobals // extension point
var GetTrafficManagerStatusExtras = func(context.Context, *daemon.UserClient) ioutil.KeyValueProvider {
var GetTrafficManagerStatusExtras = func(context.Context, daemon.UserClient) ioutil.KeyValueProvider {
return nil
}

Expand Down Expand Up @@ -254,10 +254,10 @@ func getStatusInfo(ctx context.Context, di *daemon.Info) (*StatusInfo, error) {
us := &wt.UserDaemon
us.InstallID = scout.InstallID(ctx)
us.Running = true
us.Version = userD.Version.String()
us.versionName = userD.Name
us.Executable = userD.Executable
us.Name = userD.DaemonID.Name
us.Version = userD.Semver().String()
us.versionName = userD.Name()
us.Executable = userD.Executable()
us.Name = userD.DaemonID().Name

if userD.Containerized() {
us.InDocker = true
Expand All @@ -266,7 +266,7 @@ func getStatusInfo(ctx context.Context, di *daemon.Info) (*StatusInfo, error) {
us.Hostname = di.Hostname
us.ExposedPorts = di.ExposedPorts
}
us.ContainerNetwork = "container:" + userD.DaemonID.ContainerName()
us.ContainerNetwork = "container:" + userD.DaemonID().ContainerName()
if us.versionName == "" {
us.versionName = "Daemon"
}
Expand Down
6 changes: 3 additions & 3 deletions pkg/client/cli/cmd/version.go
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ func addDaemonVersions(ctx context.Context, kvf *ioutil.KeyValueFormatter) {
}

if userD != nil {
kvf.Add(userD.Name, "v"+userD.Version.String())
kvf.Add(userD.Name(), "v"+userD.Semver().String())
vi, err := managerVersion(ctx)
switch {
case err == nil:
Expand Down Expand Up @@ -106,8 +106,8 @@ func printVersion(cmd *cobra.Command, _ []string) error {
}
addDaemonVersions(udCtx, subKvf)
ud := daemon.GetUserClient(udCtx)
kvf.Add("Connection "+ud.DaemonID.Name, "\n"+subKvf.String())
ud.Conn.Close()
kvf.Add("Connection "+ud.DaemonID().Name, "\n"+subKvf.String())
_ = ud.Close()
}
} else {
addDaemonVersions(ctx, kvf)
Expand Down
30 changes: 11 additions & 19 deletions pkg/client/cli/connect/connector.go
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ func quitHostConnector(ctx context.Context) {
}
ud := daemon.GetUserClient(udCtx)
_, _ = ud.Quit(ctx, &emptypb.Empty{})
_ = ud.Conn.Close()
_ = ud.Close()
_ = socket.WaitUntilVanishes("user daemon", socket.UserDaemonPath(ctx), 5*time.Second)

// User daemon is responsible for killing the root daemon, but we kill it here too to cater for
Expand All @@ -94,7 +94,7 @@ func quitDockerDaemons(ctx context.Context) {
}
ud := daemon.GetUserClient(udCtx)
_, _ = ud.Quit(ctx, &emptypb.Empty{})
_ = ud.Conn.Close()
_ = ud.Close()
}
if err = daemon.WaitUntilAllVanishes(ctx, 5*time.Second); err != nil {
dlog.Error(ctx, err)
Expand Down Expand Up @@ -319,23 +319,15 @@ func getConnectorVersion(ctx context.Context, cc connector.ConnectorClient) (*co
}

func newUserDaemon(ctx context.Context, conn *grpc.ClientConn, daemonID *daemon.Identifier) (context.Context, error) {
cc := connector.NewConnectorClient(conn)
vi, err := getConnectorVersion(ctx, cc)
vi, err := getConnectorVersion(ctx, connector.NewConnectorClient(conn))
if err != nil {
return ctx, err
}
v, err := semver.Parse(strings.TrimPrefix(vi.Version, "v"))
if err != nil {
return ctx, fmt.Errorf("unable to parse version obtained from connector daemon: %w", err)
}
ctx = daemon.WithUserClient(ctx, &daemon.UserClient{
ConnectorClient: cc,
Conn: conn,
DaemonID: daemonID,
Version: v,
Name: vi.Name,
Executable: vi.Executable,
})
ctx = daemon.WithUserClient(ctx, daemon.NewUserClientFunc(conn, daemonID, v, vi.Name, vi.Executable))
return ctx, nil
}

Expand Down Expand Up @@ -376,25 +368,25 @@ func EnsureSession(ctx context.Context, useLine string, required bool) (context.
return daemon.WithSession(ctx, s), nil
}

func connectSession(ctx context.Context, useLine string, userD *daemon.UserClient, request *daemon.Request, required bool) (*daemon.Session, error) {
func connectSession(ctx context.Context, useLine string, userD daemon.UserClient, request *daemon.Request, required bool) (*daemon.Session, error) {
var ci *connector.ConnectInfo
var err error
if userD.Containerized() && !proc.RunningInContainer() {
patcher.AnnotateConnectRequest(&request.ConnectRequest, docker.TpCache, userD.DaemonID.KubeContext)
patcher.AnnotateConnectRequest(&request.ConnectRequest, docker.TpCache, userD.DaemonID().KubeContext)
}
session := func(ci *connector.ConnectInfo, started bool) *daemon.Session {
// Update the request from the connect info.
request.KubeFlags = ci.KubeFlags
request.ManagerNamespace = ci.ManagerNamespace
request.Name = ci.ConnectionName
userD.DaemonID = &daemon.Identifier{
userD.SetDaemonID(&daemon.Identifier{
Name: ci.ConnectionName,
KubeContext: ci.ClusterContext,
Namespace: ci.Namespace,
Containerized: userD.Containerized(),
}
})
return &daemon.Session{
UserClient: *userD,
UserClient: userD,
Info: ci,
Started: started,
}
Expand Down Expand Up @@ -482,7 +474,7 @@ func connectSession(ctx context.Context, useLine string, userD *daemon.UserClien
}

if !userD.Containerized() {
daemonID := userD.DaemonID
daemonID := userD.DaemonID()
err = daemon.SaveInfo(ctx,
&daemon.Info{
InDocker: false,
Expand All @@ -498,7 +490,7 @@ func connectSession(ctx context.Context, useLine string, userD *daemon.UserClien
}
if ci, err = userD.Connect(ctx, &request.ConnectRequest); err != nil {
if !userD.Containerized() {
_ = daemon.DeleteInfo(ctx, userD.DaemonID.InfoFileName())
_ = daemon.DeleteInfo(ctx, userD.DaemonID().InfoFileName())
}
return nil, err
}
Expand Down
6 changes: 3 additions & 3 deletions pkg/client/cli/connect/version_check.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ var validPrerelRx = regexp.MustCompile(`^[a-z]+\.\d+$`)
func versionCheck(ctx context.Context, daemonBinary string) error {
// Ensure that the already running daemons have the correct version
userD := daemon.GetUserClient(ctx)
uv := userD.Version
uv := userD.Semver()
if userD.Containerized() {
// The user-daemon is remote (in a docker container, most likely). Compare the major, minor, and patch. Only
// compare pre-release if it's rc.X or test.X, and don't check if the binaries match.
Expand All @@ -38,9 +38,9 @@ func versionCheck(ctx context.Context, daemonBinary string) error {
return errcat.User.Newf("version mismatch. Client %s != user daemon %s, please run 'telepresence quit -s' and reconnect",
version.Version, uv)
}
if daemonBinary != "" && userD.Executable != daemonBinary {
if daemonBinary != "" && userD.Executable() != daemonBinary {
return errcat.User.Newf("executable mismatch. Connector using %s, configured to use %s, please run 'telepresence quit -s' and reconnect",
userD.Executable, daemonBinary)
userD.Executable(), daemonBinary)
}
vr, err := userD.RootDaemonVersion(ctx, &empty.Empty{})
if err == nil && version.Version != vr.Version {
Expand Down
76 changes: 62 additions & 14 deletions pkg/client/cli/daemon/userd.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package daemon

import (
"context"
"io"
"strconv"
"strings"

Expand All @@ -11,13 +12,32 @@ import (
"github.com/telepresenceio/telepresence/rpc/v2/connector"
)

type UserClient struct {
type UserClient interface {
connector.ConnectorClient
Conn *grpc.ClientConn
DaemonID *Identifier
Version semver.Version
Executable string
Name string
io.Closer
Conn() *grpc.ClientConn
Containerized() bool
DaemonPort() int
DaemonID() *Identifier
Executable() string
Name() string
Semver() semver.Version
SetDaemonID(*Identifier)
}

type userClient struct {
connector.ConnectorClient
conn *grpc.ClientConn
daemonID *Identifier
version semver.Version
executable string
name string
}

var NewUserClientFunc = NewUserClient //nolint:gochecknoglobals // extension point

func NewUserClient(conn *grpc.ClientConn, daemonID *Identifier, version semver.Version, name string, executable string) UserClient {
return &userClient{ConnectorClient: connector.NewConnectorClient(conn), conn: conn, daemonID: daemonID, version: version, name: name, executable: executable}
}

type Session struct {
Expand All @@ -28,14 +48,14 @@ type Session struct {

type userDaemonKey struct{}

func GetUserClient(ctx context.Context) *UserClient {
if ud, ok := ctx.Value(userDaemonKey{}).(*UserClient); ok {
func GetUserClient(ctx context.Context) UserClient {
if ud, ok := ctx.Value(userDaemonKey{}).(UserClient); ok {
return ud
}
return nil
}

func WithUserClient(ctx context.Context, ud *UserClient) context.Context {
func WithUserClient(ctx context.Context, ud UserClient) context.Context {
return context.WithValue(ctx, userDaemonKey{}, ud)
}

Expand All @@ -52,13 +72,37 @@ func WithSession(ctx context.Context, s *Session) context.Context {
return context.WithValue(ctx, sessionKey{}, s)
}

func (ud *UserClient) Containerized() bool {
return ud.DaemonID.Containerized
func (u *userClient) Close() error {
return u.conn.Close()
}

func (u *userClient) Conn() *grpc.ClientConn {
return u.conn
}

func (u *userClient) Containerized() bool {
return u.daemonID.Containerized
}

func (ud *UserClient) DaemonPort() int {
if ud.DaemonID.Containerized {
addr := ud.Conn.Target()
func (u *userClient) DaemonID() *Identifier {
return u.daemonID
}

func (u *userClient) Executable() string {
return u.executable
}

func (u *userClient) Name() string {
return u.name
}

func (u *userClient) Semver() semver.Version {
return u.version
}

func (u *userClient) DaemonPort() int {
if u.daemonID.Containerized {
addr := u.conn.Target()
if lc := strings.LastIndexByte(addr, ':'); lc >= 0 {
if port, err := strconv.Atoi(addr[lc+1:]); err == nil {
return port
Expand All @@ -67,3 +111,7 @@ func (ud *UserClient) DaemonPort() int {
}
return -1
}

func (u *userClient) SetDaemonID(daemonID *Identifier) {
u.daemonID = daemonID
}
2 changes: 1 addition & 1 deletion pkg/client/cli/intercept/docker_run.go
Original file line number Diff line number Diff line change
Expand Up @@ -215,7 +215,7 @@ func (s *state) startInDocker(ctx context.Context, name, envFile string, args []
ourArgs = append(ourArgs, "-v", fmt.Sprintf("%s:%s", s.mountPoint, dockerMount))
}
} else {
daemonName := ud.DaemonID.ContainerName()
daemonName := ud.DaemonID().ContainerName()
ourArgs = append(ourArgs, "--network", "container:"+daemonName)

if !(s.mountDisabled || s.info == nil) {
Expand Down
2 changes: 1 addition & 1 deletion pkg/client/cli/intercept/state.go
Original file line number Diff line number Diff line change
Expand Up @@ -368,7 +368,7 @@ func (s *state) runCommand(ctx context.Context) error {
// Ensure that the intercept handler is stopped properly if the daemon quits
procCtx, cancel := context.WithCancel(ctx)
go func() {
if err := daemon.CancelWhenRmFromCache(procCtx, cancel, ud.DaemonID.InfoFileName()); err != nil {
if err := daemon.CancelWhenRmFromCache(procCtx, cancel, ud.DaemonID().InfoFileName()); err != nil {
dlog.Error(ctx)
}
}()
Expand Down
35 changes: 17 additions & 18 deletions pkg/client/userd/daemon/grpc.go
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ func (s *service) callCtx(ctx context.Context, name string) context.Context {
return dgroup.WithGoroutineName(ctx, fmt.Sprintf("/%s-%d", name, num))
}

func (s *service) logCall(c context.Context, callName string, f func(context.Context)) {
func (s *service) LogCall(c context.Context, callName string, f func(context.Context)) {
c = s.callCtx(c, callName)
dlog.Debug(c, "called")
defer dlog.Debug(c, "returned")
Expand All @@ -77,7 +77,7 @@ func (s *service) FuseFTPError() error {
}

func (s *service) WithSession(c context.Context, callName string, f func(context.Context, userd.Session) error) (err error) {
s.logCall(c, callName, func(_ context.Context) {
s.LogCall(c, callName, func(_ context.Context) {
if atomic.LoadInt32(&s.sessionQuitting) != 0 {
err = status.Error(codes.Canceled, "session cancelled")
return
Expand Down Expand Up @@ -116,26 +116,25 @@ func (s *service) Version(_ context.Context, _ *empty.Empty) (*common.VersionInf
}, nil
}

func (s *service) Connect(ctx context.Context, cr *rpc.ConnectRequest) (result *rpc.ConnectInfo, err error) {
s.logCall(ctx, "Connect", func(c context.Context) {
select {
case <-ctx.Done():
err = status.Error(codes.Unavailable, ctx.Err().Error())
return
case s.connectRequest <- cr:
}
type crImpl struct {
*rpc.ConnectRequest
}

select {
case <-ctx.Done():
err = status.Error(codes.Unavailable, ctx.Err().Error())
case result = <-s.connectResponse:
func (c crImpl) Request() *rpc.ConnectRequest {
return c.ConnectRequest
}

func (s *service) Connect(ctx context.Context, cr *rpc.ConnectRequest) (result *rpc.ConnectInfo, err error) {
s.LogCall(ctx, "Connect", func(c context.Context) {
if err = s.PostConnectRequest(ctx, crImpl{ConnectRequest: cr}); err == nil {
result, err = s.ReadConnectResponse(ctx)
}
})
return result, err
}

func (s *service) Disconnect(ctx context.Context, ex *empty.Empty) (*empty.Empty, error) {
s.logCall(ctx, "Disconnect", func(ctx context.Context) {
s.LogCall(ctx, "Disconnect", func(ctx context.Context) {
s.cancelSession()
_ = s.withRootDaemon(ctx, func(ctx context.Context, rd daemon.DaemonClient) error {
_, err := rd.Disconnect(ctx, ex)
Expand All @@ -146,7 +145,7 @@ func (s *service) Disconnect(ctx context.Context, ex *empty.Empty) (*empty.Empty
}

func (s *service) Status(ctx context.Context, ex *empty.Empty) (result *rpc.ConnectInfo, err error) {
s.logCall(ctx, "Status", func(c context.Context) {
s.LogCall(ctx, "Status", func(c context.Context) {
s.sessionLock.RLock()
defer s.sessionLock.RUnlock()
if s.session == nil {
Expand Down Expand Up @@ -391,7 +390,7 @@ func (s *service) GatherLogs(ctx context.Context, request *rpc.LogsRequest) (res
}

func (s *service) SetLogLevel(ctx context.Context, request *rpc.LogLevelRequest) (result *empty.Empty, err error) {
s.logCall(ctx, "SetLogLevel", func(c context.Context) {
s.LogCall(ctx, "SetLogLevel", func(c context.Context) {
mrq := &manager.LogLevelRequest{
LogLevel: request.LogLevel,
Duration: request.Duration,
Expand Down Expand Up @@ -432,7 +431,7 @@ func (s *service) SetLogLevel(ctx context.Context, request *rpc.LogLevelRequest)
}

func (s *service) Quit(ctx context.Context, ex *empty.Empty) (*empty.Empty, error) {
s.logCall(ctx, "Quit", func(c context.Context) {
s.LogCall(ctx, "Quit", func(c context.Context) {
s.sessionLock.RLock()
defer s.sessionLock.RUnlock()
s.cancelSessionReadLocked()
Expand Down
Loading

0 comments on commit cd4e4e5

Please sign in to comment.