Skip to content

Commit

Permalink
flagutil: make LinkFlags() to be unidirectional LinkFlag()
Browse files Browse the repository at this point in the history
  • Loading branch information
gobwas committed Sep 21, 2021
1 parent 465da5a commit 1461afe
Show file tree
Hide file tree
Showing 2 changed files with 78 additions and 99 deletions.
89 changes: 30 additions & 59 deletions flagutil.go
Original file line number Diff line number Diff line change
Expand Up @@ -553,68 +553,39 @@ func SetActual(fs *flag.FlagSet, name string) {
}
}

// LinkFlags links flags named as n0 and n1 in existing flag set fs.
// If any of the flags doesn't exist LinkFlags() will create one.
func LinkFlags(fs *flag.FlagSet, n0, n1 string) {
var (
u0 string
u1 string
v0 flag.Value
v1 flag.Value
)
f0 := fs.Lookup(n0)
if f0 != nil {
v0 = f0.Value
u0 = f0.Usage
// LinkFlag links dst to be updated when src value is set.
// It panics if any of the given names doesn't exist in fs.
//
// Note that it caches the both src and dst flag.Value pointers internally, so
// it is possible to link src to dst and dst to src without infinite recursion.
// However, if any of the src or dst flag value get overwritten after
// LinkFlag() call, created link will not work properly anymore.
func LinkFlag(fs *flag.FlagSet, src, dst string) {
srcFlag := fs.Lookup(src)
if srcFlag == nil {
panic(fmt.Sprintf(
"flagutil: link flag: source flag %q must exist",
src,
))
}
f1 := fs.Lookup(n1)
if f1 != nil {
v1 = f1.Value
u1 = f1.Usage
dstFlag := fs.Lookup(dst)
if dstFlag == nil {
panic(fmt.Sprintf(
"flagutil: link flag: destination flag %q must exist",
src,
))
}

usage := mergeUsage(n0+","+n1, u0, u1)

v := value{
doSet: func(s string) (err error) {
if err == nil && v0 != nil {
err = v0.Set(s)
}
if err == nil && v1 != nil {
err = v1.Set(s)
}
var (
srcValue = srcFlag.Value
dstValue = dstFlag.Value
)
srcFlag.Value = OverrideSet(srcFlag.Value, func(s string) error {
err := srcValue.Set(s)
if err != nil {
return err
},
doIsBoolFlag: func() bool {
if v0 == nil || v1 == nil {
// Can't guess in advance.
return false
}
return isBoolValue(v0) && isBoolValue(v1)
},
doString: func() string {
if v0 == nil || v1 == nil {
// Can't guess in advance.
return ""
}
s0 := v0.String()
s1 := v1.String()
if s0 == s1 {
return s0
}
return ""
},
}
if f0 != nil {
f0.Value = v
} else {
fs.Var(v, n0, usage)
}
if f1 != nil {
f1.Value = v
} else {
fs.Var(v, n1, usage)
}
}
return dstValue.Set(s)
})
}

func mergeUsage(name, s0, s1 string) string {
Expand Down
88 changes: 48 additions & 40 deletions flagutil_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -101,11 +101,11 @@ func TestUnquoteUsage(t *testing.T) {
type expMode map[UnquoteUsageMode][2]string
for _, test := range []struct {
name string
flag flag.Flag
flag *flag.Flag
modes expMode
}{
{
flag: flag.Flag{
flag: &flag.Flag{
Usage: "foo `bar` baz",
},
modes: expMode{
Expand Down Expand Up @@ -149,7 +149,7 @@ func TestUnquoteUsage(t *testing.T) {
t.Run(test.name, func(t *testing.T) {
for mode, exp := range test.modes {
t.Run(mode.String(), func(t *testing.T) {
actName, actUsage := unquoteUsage(mode, &test.flag)
actName, actUsage := unquoteUsage(mode, test.flag)
if expName := exp[0]; actName != expName {
t.Errorf("unexpected name:\n%s", cmp.Diff(expName, actName))
}
Expand Down Expand Up @@ -362,37 +362,37 @@ func isActual(fs *flag.FlagSet, name string) (actual bool) {
func TestCombineFlags(t *testing.T) {
for _, test := range []struct {
name string
flags [2]flag.Flag
exp flag.Flag
flags [2]*flag.Flag
exp *flag.Flag
panic bool
}{
{
name: "different names",
flags: [2]flag.Flag{
flags: [2]*flag.Flag{
stringFlag("foo", "def", "desc#0"),
stringFlag("bar", "def", "desc#1"),
},
panic: true,
},
{
name: "different default values",
flags: [2]flag.Flag{
flags: [2]*flag.Flag{
stringFlag("foo", "def#0", "desc#0"),
stringFlag("foo", "def#1", "desc#1"),
},
exp: stringFlag("foo", "", "desc#0 / desc#1"),
},
{
name: "basic",
flags: [2]flag.Flag{
flags: [2]*flag.Flag{
stringFlag("foo", "def", "desc#0"),
stringFlag("foo", "def", "desc#1"),
},
exp: stringFlag("foo", "def", "desc#0 / desc#1"),
},
{
name: "basic",
flags: [2]flag.Flag{
flags: [2]*flag.Flag{
stringFlag("foo", "def", "desc#0"),
stringFlag("foo", "", "desc#1"),
},
Expand All @@ -414,7 +414,7 @@ func TestCombineFlags(t *testing.T) {
}
}()
done <- flagOrPanic{
flag: CombineFlags(&test.flags[0], &test.flags[1]),
flag: CombineFlags(test.flags[0], test.flags[1]),
}
}()
x := <-done
Expand All @@ -432,65 +432,73 @@ func TestCombineFlags(t *testing.T) {
return v.String()
}),
}
if act, exp := x.flag, &test.exp; !cmp.Equal(act, exp, opts...) {
if act, exp := x.flag, test.exp; !cmp.Equal(act, exp, opts...) {
t.Errorf("unexpected flag:\n%s", cmp.Diff(exp, act, opts...))
}
exp := fmt.Sprintf("%x", rand.Int63())
if err := x.flag.Value.Set(exp); err != nil {
t.Fatalf("unexpected Set() error: %v", err)
}
for _, f := range test.flags {
if act := f.Value.String(); act != exp {
t.Errorf("unexpected flag value: %s; want %s", act, exp)
}
assertEquals(t, f, exp)
}
})
}
}

func TestLinkFlags(t *testing.T) {
func TestLinkFlag(t *testing.T) {
for _, test := range []struct {
name string
flags [2]flag.Flag
flags [2]*flag.Flag
links [2]string
}{
{
name: "basic",
flags: [2]flag.Flag{
flags: [2]*flag.Flag{
stringFlag("foo", "def#0", "desc#0"),
stringFlag("bar", "def#1", "desc#1"),
},
links: [2]string{"foo", "bar"},
},
} {
for i := 0; i < 2; i++ {
t.Run(test.name, func(t *testing.T) {
fs := flag.NewFlagSet("", flag.PanicOnError)
for _, f := range test.flags {
t.Run(test.name, func(t *testing.T) {
fs := flag.NewFlagSet("", flag.PanicOnError)
for _, f := range test.flags {
if f != nil {
fs.Var(f.Value, f.Name, f.Usage)
}
LinkFlags(fs,
test.flags[0].Name,
test.flags[1].Name,
)

exp := fmt.Sprintf("%x", rand.Int63())
fs.Set(test.flags[i].Name, exp)

for _, f := range test.flags {
if act := f.Value.String(); act != exp {
t.Errorf(
"unexpected flag %q value: %s; want %s",
f.Name, act, exp,
)
}
}
LinkFlag(fs, test.links[0], test.links[1])

// First, test that setting for src flag affects dst flag.
exp := fmt.Sprintf("%x", rand.Int63())
fs.Set(test.links[0], exp)
for _, n := range test.links {
if f := fs.Lookup(n); f != nil {
assertEquals(t, f, exp)
}
})
}
}
// Second, test that setting dst flag doesn't affect src flag.
nonExp := fmt.Sprintf("%x", rand.Int63())
fs.Set(test.flags[1].Name, nonExp)
assertEquals(t, test.flags[0], exp) // Still the same.
assertEquals(t, test.flags[1], nonExp) // Updated.
})
}
}

func assertEquals(t *testing.T, f *flag.Flag, exp string) {
if act := f.Value.String(); act != exp {
t.Errorf(
"unexpected flag %q value: %s; want %s",
f.Name, act, exp,
)
}
}

func stringFlag(name, def, desc string) flag.Flag {
func stringFlag(name, def, desc string) *flag.Flag {
fs := flag.NewFlagSet("", flag.PanicOnError)
fs.String(name, def, desc)
f := fs.Lookup(name)
return *f
return f
}

0 comments on commit 1461afe

Please sign in to comment.