diff --git a/uinit/doc.go b/uinit/doc.go new file mode 100644 index 00000000..eb238c15 --- /dev/null +++ b/uinit/doc.go @@ -0,0 +1,66 @@ +// dut manages Devices Under Test (a.k.a. DUT) from a host. +// A primary goal is allowing multiple hosts with any architecture to connect. +// +// This program was designed to be used in u-root images, as the uinit, +// or in other initramfs systems. It can not function as a standalone +// init: it assumes network is set up, for example. +// +// In this document, dut refers to this program, and DUT refers to +// Devices Under Test. Hopefully this is not too confusing, but it is +// convenient. Also, please note: DUT is plural (Devices). We don't need +// to say DUTs -- at least one is assumed. +// +// The same dut binary runs on host and DUT, in either device mode (i.e. +// on the DUT), or in some host-specific mode. The mode is chosen by +// the first non-flag argument. If there are flags specific to that mode, +// they follow that argument. +// E.g., when uinit is run on the host and we want it to enable cpu daemons +// on the DUT, we run it as follows: +// dut cpu -key ... +// the -key switch is only valid following the cpu mode argument. +// +// modes +// dut currently supports 3 modes. +// +// The first, default, mode, is "device". In device mode, dut makes an http connection +// to a dut running on a host, then starts an HTTP RPC server. +// +// The second mode is "tester". In this mode, dut calls the Welcome service, followed +// by the Reboot service. Tester can be useful, run by a shell script in a for loop, for +// ensure reboot is reliable. +// +// The third mode is "cpu". dut will direct the DUT to start a cpu service, and block until +// it exits. Flags for this service: +// pubkey: name of the public key file +// hostkey: name of the host key file +// cpuport: port on which to serve the cpu service +// +// Theory of Operation +// dut runs on the host, accepting connections from DUT, and controlling them via +// Go HTTP RPC commands. As each command is executed, its response is printed. +// Commands are: +// +// Welcome -- get a welcome message +// Argument: None +// Return: a welcome message in cowsay format: +// < welcome to DUT > +// -------------- +// \ ^__^ +// \ (oo)\_______ +// (__)\ )\/\ +// ||----w | +// || || +// +// Die -- force dut on DUT to exit +// Argument: time to sleep before exiting as a time.Duration +// Return: no return; kills the program running on DUT +// +// Reboot +// Argument: time to sleep before rebooting as a time.Duration +// +// CPU -- Start a CPU server on DUT +// Arguments: public key and host key as a []byte, service port as a string +// Returns: returns (possibly nil) error exit value of cpu server; blocks until it is done +// +// +package main diff --git a/uinit/dut.go b/uinit/dut.go new file mode 100644 index 00000000..1f9964d3 --- /dev/null +++ b/uinit/dut.go @@ -0,0 +1,189 @@ +// This is a very simple dut program. It builds into one binary to implement +// both client and server. It's just easier to see both sides of the code and test +// that way. +package main + +import ( + "flag" + "fmt" + "io/ioutil" + "log" + "net" + "net/rpc" + "os" + "time" + + "github.com/u-root/u-root/pkg/ulog" + "golang.org/x/sys/unix" +) + +var ( + debug = flag.Bool("d", false, "Enable debug prints") + host = flag.String("host", "192.168.0.1", "hostname") + klog = flag.Bool("klog", false, "Direct all logging to klog -- depends on debug") + port = flag.String("port", "8080", "port number") + dir = flag.String("dir", ".", "directory to serve") + + // for debug + v = func(string, ...interface{}) {} +) + +func dutStart(t, host, port string) (net.Listener, error) { + ln, err := net.Listen(t, host+":"+port) + if err != nil { + log.Print(err) + return nil, err + } + log.Printf("Listening on %v at %v", ln.Addr(), time.Now()) + return ln, nil +} + +func dutAccept(l net.Listener) (net.Conn, error) { + if err := l.(*net.TCPListener).SetDeadline(time.Now().Add(3 * time.Minute)); err != nil { + return nil, err + } + c, err := l.Accept() + if err != nil { + log.Printf("Listen failed: %v at %v", err, time.Now()) + log.Print(err) + return nil, err + } + log.Printf("Accepted %v", c) + return c, nil +} + +func dutRPC(host, port string) error { + l, err := dutStart("tcp", host, port) + if err != nil { + return err + } + c, err := dutAccept(l) + if err != nil { + return err + } + cl := rpc.NewClient(c) + for _, cmd := range []struct { + call string + args interface{} + }{ + {"Command.Welcome", &RPCWelcome{}}, + {"Command.Reboot", &RPCReboot{}}, + } { + var r RPCRes + if err := cl.Call(cmd.call, cmd.args, &r); err != nil { + return err + } + fmt.Printf("%v(%v): %v\n", cmd.call, cmd.args, string(r.C)) + } + + if c, err = dutAccept(l); err != nil { + return err + } + cl = rpc.NewClient(c) + var r RPCRes + if err := cl.Call("Command.Welcome", &RPCWelcome{}, &r); err != nil { + return err + } + fmt.Printf("%v(%v): %v\n", "Command.Welcome", nil, string(r.C)) + + return nil +} + +func dutcpu(host, port, pubkey, hostkey, cpuport string) error { + var req = &RPCCPU{Port: cpuport} + var err error + + // we send the pubkey and hostkey as the value of the key, not the + // name of the file. + // TODO: maybe use ssh_config to find keys? the cpu client can do that. + // Note: the public key is not optional. That said, we do not test + // for len(*pubKey) > 0; if it is set to ""< ReadFile will return + // an error. + if req.PubKey, err = ioutil.ReadFile(pubkey); err != nil { + return fmt.Errorf("Reading pubKey:%w", err) + } + if len(hostkey) > 0 { + if req.HostKey, err = ioutil.ReadFile(hostkey); err != nil { + return fmt.Errorf("Reading hostKey:%w", err) + } + } + + l, err := dutStart("tcp", host, port) + if err != nil { + return err + } + + c, err := dutAccept(l) + if err != nil { + return err + } + + cl := rpc.NewClient(c) + + for _, cmd := range []struct { + call string + args interface{} + }{ + {"Command.Welcome", &RPCWelcome{}}, + {"Command.Welcome", &RPCWelcome{}}, + {"Command.CPU", req}, + } { + var r RPCRes + if err := cl.Call(cmd.call, cmd.args, &r); err != nil { + return err + } + fmt.Printf("%v(%v): %v\n", cmd.call, cmd.args, string(r.C)) + } + return err +} + +func main() { + // for CPU + flag.Parse() + + if *debug { + v = log.Printf + if *klog { + ulog.KernelLog.Reinit() + v = ulog.KernelLog.Printf + } + } + a := flag.Args() + if len(a) == 0 { + a = []string{"device"} + } + + os.Args = a + var err error + v("Mode is %v", a[0]) + switch a[0] { + case "tester": + err = dutRPC(*host, *port) + case "cpu": + var ( + pubKey = flag.String("pubkey", "key.pub", "public key file") + hostKey = flag.String("hostkey", "", "host key file -- usually empty") + cpuPort = flag.String("cpuport", "17010", "cpu port -- IANA value is ncpu tcp/17010") + ) + v("Parse %v", os.Args) + flag.Parse() + v("pubkey %v", *pubKey) + if err := dutcpu(*host, *port, *pubKey, *hostKey, *cpuPort); err != nil { + log.Printf("cpu service: %v", err) + } + case "device": + err = uinit(*host, *port) + // What to do after a return? Reboot I suppose. + log.Printf("Device returns with error %v", err) + if err := unix.Reboot(int(unix.LINUX_REBOOT_CMD_RESTART)); err != nil { + log.Printf("Reboot failed, not sure what to do now.") + } + default: + log.Printf("Unknown mode %v", a[0]) + } + log.Printf("We are now done ......................") + if err != nil { + log.Printf("%v", err) + os.Exit(2) + } +} diff --git a/uinit/dut_test.go b/uinit/dut_test.go new file mode 100644 index 00000000..f5990c91 --- /dev/null +++ b/uinit/dut_test.go @@ -0,0 +1,52 @@ +package main + +import ( + "net/rpc" + "testing" + "time" +) + +func TestUinit(t *testing.T) { + var tests = []struct { + c string + r interface{} + err string + }{ + {c: "Welcome", r: RPCWelcome{}}, + {c: "Reboot", r: RPCReboot{}}, + {c: "Kexec", r: RPCKexec{When: 5 * time.Second}, err: "Not yet"}, + } + l, err := dutStart("tcp", "localhost", "") + if err != nil { + t.Fatal(err) + } + + a := l.Addr() + t.Logf("listening on %v", a) + // Kick off our node. + go func() { + time.Sleep(1) + if err := uinit(a.Network(), a.String(), "17010"); err != nil { + t.Fatalf("starting uinit: got %v, want nil", err) + } + }() + + c, err := dutAccept(l) + if err != nil { + t.Fatal(err) + } + t.Logf("Connected on %v", c) + + cl := rpc.NewClient(c) + for _, tt := range tests { + t.Run(tt.c, func(t *testing.T) { + var r RPCRes + if err = cl.Call("Command."+tt.c, tt.r, &r); err != nil { + t.Fatalf("Call to %v: got %v, want nil", tt.c, err) + } + if r.Err != tt.err { + t.Errorf("%v: got %v, want %v", tt, r.Err, tt.err) + } + }) + } +} diff --git a/uinit/rpc.go b/uinit/rpc.go new file mode 100644 index 00000000..2d2aaf9d --- /dev/null +++ b/uinit/rpc.go @@ -0,0 +1,79 @@ +package main + +import ( + "fmt" + "log" + "os" + "time" + + "golang.org/x/sys/unix" +) + +type RPCRes struct { + C []byte + Err string +} + +type Command int + +type RPCWelcome struct { +} + +func (*Command) Welcome(args *RPCWelcome, r *RPCRes) error { + r.C = []byte(welcome) + r.Err = "" + log.Printf("welcome") + return nil +} + +type RPCExit struct { + When time.Duration +} + +func (*Command) Die(args *RPCExit, r *RPCRes) error { + go func() { + time.Sleep(args.When) + log.Printf("die exits") + os.Exit(0) + }() + *r = RPCRes{} + log.Printf("die returns") + return nil +} + +type RPCReboot struct { + When time.Duration +} + +func (*Command) Reboot(args *RPCReboot, r *RPCRes) error { + go func() { + time.Sleep(args.When) + if err := unix.Reboot(unix.LINUX_REBOOT_CMD_RESTART); err != nil { + log.Printf("%v\n", err) + } + }() + *r = RPCRes{} + log.Printf("reboot returns") + return nil +} + +type RPCCPU struct { + PubKey []byte + HostKey []byte + Port string +} + +func (*Command) CPU(args *RPCCPU, r *RPCRes) error { + v("CPU") + res := make(chan error) + go func(pubKey, hostKey []byte, port string) { + v("cpu serve(%q,%q,%q)", pubKey, hostKey, port) + err := serve(pubKey, hostKey, port) + v("cpu serve returns") + res <- err + }(args.PubKey, args.HostKey, args.Port) + err := <-res + *r = RPCRes{Err: fmt.Sprintf("%v", err)} + v("cpud returns") + return nil +} diff --git a/uinit/serve.go b/uinit/serve.go new file mode 100644 index 00000000..52eccb4f --- /dev/null +++ b/uinit/serve.go @@ -0,0 +1,66 @@ +// Copyright 2022 the u-root Authors. All rights reserved +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package main + +import ( + "fmt" + "io/ioutil" + "log" + "net" + "time" + + // We use this ssh because it implements port redirection. + // It can not, however, unpack password-protected keys yet. + "github.com/gliderlabs/ssh" // TODO: get rid of krpty + "github.com/u-root/cpu/server" + "golang.org/x/sys/unix" +) + +// hang hangs for a VERY long time. +// This aids diagnosis, else you lose all messages in the +// kernel panic as init exits. +func hang() { + log.Printf("hang") + time.Sleep(10000 * time.Second) + log.Printf("done hang") +} + +func serve(pubKey, hostKey []byte, port string) error { + if err := unix.Mount("cpu", "/tmp", "tmpfs", 0, ""); err != nil { + log.Printf("CPUD:Warning: tmpfs mount on /tmp (%v) failed. There will be no 9p mount", err) + } + + // Note that the keys are in a private mount; no need for a temp file. + if err := ioutil.WriteFile("/tmp/key.pub", pubKey, 0644); err != nil { + return fmt.Errorf("writing pubkey: %w", err) + } + if len(hostKey) > 0 { + if err := ioutil.WriteFile("/tmp/hostkey", hostKey, 0644); err != nil { + return fmt.Errorf("writing hostkey: %w", err) + } + } + + v("Kicked off startup jobs, now serve ssh") + s, err := server.New("/tmp/key.pub", "/tmp/hostkey") + if err != nil { + log.Printf(`New(%q, %q): %v != nil`, "/tmp/key.pub", "/tmp/hostkey", err) + hang() + } + v("Server is %v", s) + + ln, err := net.Listen("tcp", ":"+port) + if err != nil { + log.Printf("net.Listen(): %v != nil", err) + hang() + } + v("Listening on %v", ln.Addr()) + if err := s.Serve(ln); err != ssh.ErrServerClosed { + log.Printf("s.Daemon(): %v != %v", err, ssh.ErrServerClosed) + hang() + } + v("Daemon returns") + hang() + return nil +} diff --git a/uinit/uinit.go b/uinit/uinit.go new file mode 100644 index 00000000..1a3c77d0 --- /dev/null +++ b/uinit/uinit.go @@ -0,0 +1,62 @@ +package main + +import ( + "log" + "net" + "net/rpc" + "os" + "time" + + "github.com/cenkalti/backoff/v4" +) + +var ( + rebooting = "Rebooting!" + welcome = ` ______________ +< welcome to DUT > + -------------- + \ ^__^ + \ (oo)\_______ + (__)\ )\/\ + ||----w | + || || +` +) + +func uinit(r, p string) error { + log.Printf("here we are in uinit") + log.Printf("UINIT uid is %d", os.Getuid()) + + na := r + ":" + p + log.Printf("Now dial %v", na) + b := backoff.NewExponentialBackOff() + // We'll go at it for 5 minutes, then reboot. + b.MaxElapsedTime = 5 * time.Minute + + var c net.Conn + f := func() error { + nc, err := net.Dial("tcp", na) + if err != nil { + log.Printf("Dial went poorly") + return err + } + c = nc + return nil + } + if err := backoff.Retry(f, b); err != nil { + return err + } + log.Printf("Start the RPC server") + var Cmd Command + s := rpc.NewServer() + log.Printf("rpc server is %v", s) + if err := s.Register(&Cmd); err != nil { + log.Printf("register failed: %v", err) + return err + } + log.Printf("Serve and protect") + s.ServeConn(c) + log.Printf("And uinit is all done.") + return nil + +}