diff --git a/go.mod b/go.mod index 8e4bf722..cdb7d65c 100644 --- a/go.mod +++ b/go.mod @@ -3,18 +3,25 @@ module sigs.k8s.io/krew go 1.12 require ( + github.com/dsnet/compress v0.0.1 // indirect github.com/fatih/color v1.7.0 + github.com/frankban/quicktest v1.6.0 // indirect github.com/gogo/protobuf v1.2.1 // indirect github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b - github.com/google/go-cmp v0.3.0 + github.com/golang/snappy v0.0.1 // indirect + github.com/google/go-cmp v0.3.1 github.com/inconshreveable/mousetrap v1.0.0 // indirect github.com/kylelemons/godebug v1.1.0 // indirect github.com/mattn/go-colorable v0.1.2 // indirect github.com/mattn/go-isatty v0.0.8 + github.com/mholt/archiver v3.1.1+incompatible + github.com/nwaples/rardecode v1.0.0 // indirect + github.com/pierrec/lz4 v2.3.0+incompatible // indirect github.com/pkg/errors v0.8.0 github.com/sahilm/fuzzy v0.0.5 github.com/spf13/cobra v0.0.3 github.com/spf13/pflag v1.0.3 + github.com/xi2/xz v0.0.0-20171230120015-48954b6210f8 // indirect golang.org/x/net v0.0.0-20190628185345-da137c7871d7 // indirect golang.org/x/sys v0.0.0-20190712062909-fae7ac547cb7 // indirect gopkg.in/inf.v0 v0.9.1 // indirect diff --git a/go.sum b/go.sum index 3362c142..438fa71d 100644 --- a/go.sum +++ b/go.sum @@ -6,11 +6,16 @@ github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSs github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/docker/spdystream v0.0.0-20160310174837-449fdfce4d96/go.mod h1:Qh8CwZgvJUkLughtfhJv5dyTYa91l1fOUCrgjqmcifM= +github.com/dsnet/compress v0.0.1 h1:PlZu0n3Tuv04TzpfPbrnI0HW/YwodEXDS+oPKahKF0Q= +github.com/dsnet/compress v0.0.1/go.mod h1:Aw8dCMJ7RioblQeTqt88akK31OvO8Dhf5JflhBbQEHo= +github.com/dsnet/golib v0.0.0-20171103203638-1ea166775780/go.mod h1:Lj+Z9rebOhdfkVLjJ8T6VcRQv3SXugXy999NBtR9aFY= github.com/elazarl/goproxy v0.0.0-20170405201442-c4fc26588b6e/go.mod h1:/Zj4wYkgs4iZTTu3o/KG3Itv/qCCa8VVMlb3i9OVuzc= github.com/emicklei/go-restful v0.0.0-20170410110728-ff4f55a20633/go.mod h1:otzb+WCGbkyDHkqmQmT5YD2WR4BBwUdeQoFo8l/7tVs= github.com/evanphx/json-patch v4.2.0+incompatible/go.mod h1:50XU6AFN0ol/bzJsmQLiYLvXMP4fmwYFNcr97nuDLSk= github.com/fatih/color v1.7.0 h1:DkWD4oS2D8LGGgTQ6IvwJJXSL5Vp2ffcQg58nFV38Ys= github.com/fatih/color v1.7.0/go.mod h1:Zm6kSWBoL9eyXnKyktHP6abPY2pDugNf5KwzbycvMj4= +github.com/frankban/quicktest v1.6.0 h1:Cd62nl66vQsx8Uv1t8M0eICyxIwZG7MxiAOrdnnUSW0= +github.com/frankban/quicktest v1.6.0/go.mod h1:jaStnuzAqU1AJdCO0l53JDCJrVDKcS03DbaAcR7Ks/o= github.com/fsnotify/fsnotify v1.4.7/go.mod h1:jwhsz4b93w/PPRr/qN1Yymfu8t87LnFCMoQvtojpjFo= github.com/ghodss/yaml v0.0.0-20150909031657-73d445a93680 h1:ZktWZesgun21uEDrwW7iEV1zPCGQldM2atlJZ3TdvVM= github.com/ghodss/yaml v0.0.0-20150909031657-73d445a93680/go.mod h1:4dBDuWmgqj2HViK6kFavaiC9ZROes6MMH2rRYeMEF04= @@ -27,8 +32,12 @@ github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b/go.mod h1:SBH7ygxi8pfU github.com/golang/groupcache v0.0.0-20160516000752-02826c3e7903/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc= github.com/golang/protobuf v0.0.0-20161109072736-4bd1920723d7/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= +github.com/golang/snappy v0.0.1 h1:Qgr9rKW7uDUkrbSmQeiDsGa8SjGyCOGtuasMWwvp2P4= +github.com/golang/snappy v0.0.1/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q= github.com/google/go-cmp v0.3.0 h1:crn/baboCvb5fXaQ0IJ1SGTsTVrWpDsCWC8EGETZijY= github.com/google/go-cmp v0.3.0/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU= +github.com/google/go-cmp v0.3.1 h1:Xye71clBPdm5HgqGwUkwhbynsUJZhDbS20FvLhQ2izg= +github.com/google/go-cmp v0.3.1/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU= github.com/google/gofuzz v0.0.0-20161122191042-44d81051d367/go.mod h1:HP5RmnzzSNb993RKQDq4+1A4ia9nllfqcQFTQJedwGI= github.com/google/gofuzz v1.0.0 h1:A8PeW59pxE9IoFRqBp37U+mSNaQoZ46F1f0f863XSXw= github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= @@ -44,6 +53,8 @@ github.com/json-iterator/go v1.1.6 h1:MrUvLMLTMxbqFJ9kzlvat/rYZqZnW3u4wkLzWTaFwK github.com/json-iterator/go v1.1.6/go.mod h1:+SdeFBvtyEkXs7REEP0seUULqWtbJapLOCVDaaPEHmU= github.com/kisielk/errcheck v1.1.0/go.mod h1:EZBBE59ingxPouuu3KfxchcWSUPOHkagtvWXihfKN4Q= github.com/kisielk/gotool v1.0.0/go.mod h1:XhKaO+MFFWcvkIS/tQcRk01m1F5IRFswLeQ+oQHNcck= +github.com/klauspost/compress v1.4.1/go.mod h1:RyIbtBH6LamlWaDj8nUwkbUhJ87Yi3uG0guNDohfE1A= +github.com/klauspost/cpuid v1.2.0/go.mod h1:Pj4uuM528wm8OyEC2QMXAi2YiTZ96dNQPGgoMS4s3ek= github.com/kr/pretty v0.1.0 h1:L/CwN0zerZDmRFUapSPitk6f+Q3+0za1rQkzVuMiMFI= github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo= github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= @@ -56,6 +67,8 @@ github.com/mattn/go-colorable v0.1.2 h1:/bC9yWikZXAL9uJdulbSfyVNIR3n3trXl+v8+1sx github.com/mattn/go-colorable v0.1.2/go.mod h1:U0ppj6V5qS13XJ6of8GYAs25YV2eR4EVcfRqFIhoBtE= github.com/mattn/go-isatty v0.0.8 h1:HLtExJ+uU2HOZ+wI0Tt5DtUDrx8yhUqDcp7fYERX4CE= github.com/mattn/go-isatty v0.0.8/go.mod h1:Iq45c/XA43vh69/j3iqttzPXn0bhXyGjM0Hdxcsrc5s= +github.com/mholt/archiver v3.1.1+incompatible h1:1dCVxuqs0dJseYEhi5pl7MYPH9zDa1wBi7mF09cbNkU= +github.com/mholt/archiver v3.1.1+incompatible/go.mod h1:Dh2dOXnSdiLxRiPoVfIr/fI1TwETms9B8CTWfeh7ROU= github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd h1:TRLaZ9cD/w8PVh93nsPXa1VrQ6jlwL5oN8l14QlcNfg= github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= github.com/modern-go/reflect2 v0.0.0-20180320133207-05fbef0ca5da/go.mod h1:bx2lNnkwVCuqBIxFjflWJWanXIb3RllmbCylyMrvgv0= @@ -63,11 +76,15 @@ github.com/modern-go/reflect2 v1.0.1 h1:9f412s+6RmYXLWZSEzVVgPGK7C2PphHj5RJrvfx9 github.com/modern-go/reflect2 v1.0.1/go.mod h1:bx2lNnkwVCuqBIxFjflWJWanXIb3RllmbCylyMrvgv0= github.com/munnerz/goautoneg v0.0.0-20120707110453-a547fc61f48d/go.mod h1:+n7T8mK8HuQTcFwEeznm/DIxMOiR9yIdICNftLE1DvQ= github.com/mxk/go-flowrate v0.0.0-20140419014527-cca7078d478f/go.mod h1:ZdcZmHo+o7JKHSa8/e818NopupXU1YMK5fe1lsApnBw= +github.com/nwaples/rardecode v1.0.0 h1:r7vGuS5akxOnR4JQSkko62RJ1ReCMXxQRPtxsiFMBOs= +github.com/nwaples/rardecode v1.0.0/go.mod h1:5DzqNKiOdpKKBH87u8VlvAnPZMXcGRhxWkRpHbbfGS0= github.com/onsi/ginkgo v0.0.0-20170829012221-11459a886d9c/go.mod h1:lLunBs/Ym6LB5Z9jYTR76FiuTmxDTDusOGeTQH+WWjE= github.com/onsi/ginkgo v1.6.0/go.mod h1:lLunBs/Ym6LB5Z9jYTR76FiuTmxDTDusOGeTQH+WWjE= github.com/onsi/ginkgo v1.8.0/go.mod h1:lLunBs/Ym6LB5Z9jYTR76FiuTmxDTDusOGeTQH+WWjE= github.com/onsi/gomega v0.0.0-20170829124025-dcabb60a477c/go.mod h1:C1qb7wdrVGGVU+Z6iS04AVkA3Q65CEZX59MT0QO5uiA= github.com/onsi/gomega v1.5.0/go.mod h1:ex+gbHU/CVuBBDIJjb2X0qEXbFg53c61hWP/1CpauHY= +github.com/pierrec/lz4 v2.3.0+incompatible h1:CZzRn4Ut9GbUkHlQ7jqBXeZQV41ZSKWFc302ZU6lUTk= +github.com/pierrec/lz4 v2.3.0+incompatible/go.mod h1:pdkljMzZIN41W+lC3N2tnIh5sFi+IEE17M5jbnwPHcY= github.com/pkg/errors v0.8.0 h1:WdK/asTD0HN+q6hsWO3/vpuAkAr+tw6aNJNDFFf0+qw= github.com/pkg/errors v0.8.0/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pmezard/go-difflib v0.0.0-20151028094244-d8ed2627bdf0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= @@ -84,6 +101,10 @@ github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+ github.com/stretchr/testify v0.0.0-20151208002404-e3a8ff8ce365/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs= github.com/stretchr/testify v1.3.0 h1:TivCn/peBQ7UY8ooIcPgZFpTNSz0Q2U6UrFlUfqbe0Q= github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= +github.com/ulikunitz/xz v0.5.6 h1:jGHAfXawEGZQ3blwU5wnWKQJvAraT7Ftq9EXjnXYgt8= +github.com/ulikunitz/xz v0.5.6/go.mod h1:2bypXElzHzzJZwzH67Y6wb67pO62Rzfn7BSiF4ABRW8= +github.com/xi2/xz v0.0.0-20171230120015-48954b6210f8 h1:nIPpBwaJSVYIxUFsDv3M8ofmx9yWTog9BfvIu0q41lo= +github.com/xi2/xz v0.0.0-20171230120015-48954b6210f8/go.mod h1:HUYIGzjTL3rfEspMxjDjgmT5uz5wzYJKVo23qUhYTos= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= golang.org/x/net v0.0.0-20170114055629-f2499483f923/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20180906233101-161cd47e91fd/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= diff --git a/pkg/download/downloader.go b/pkg/download/downloader.go index 55b8a049..df4204cc 100644 --- a/pkg/download/downloader.go +++ b/pkg/download/downloader.go @@ -17,145 +17,35 @@ package download import ( "archive/tar" "archive/zip" - "bytes" - "compress/gzip" "io" "io/ioutil" "net/http" "os" - "path/filepath" "strings" "github.com/golang/glog" + "github.com/mholt/archiver" "github.com/pkg/errors" ) // download gets a file from the internet in memory and writes it content // to a Verifier. -func download(url string, verifier Verifier, fetcher Fetcher) (io.ReaderAt, int64, error) { +func download(url string, verifier Verifier, fetcher Fetcher) ([]byte, error) { glog.V(2).Infof("Fetching %q", url) body, err := fetcher.Get(url) if err != nil { - return nil, 0, errors.Wrapf(err, "could not download %q", url) + return nil, errors.Wrapf(err, "could not download %q", url) } defer body.Close() glog.V(3).Infof("Reading download data into memory") data, err := ioutil.ReadAll(io.TeeReader(body, verifier)) if err != nil { - return nil, 0, errors.Wrap(err, "could not read download content") + return nil, errors.Wrap(err, "could not read download content") } glog.V(2).Infof("Read %d bytes of download data into memory", len(data)) - return bytes.NewReader(data), int64(len(data)), verifier.Verify() -} - -// extractZIP extracts a zip file into the target directory. -func extractZIP(targetDir string, read io.ReaderAt, size int64) error { - glog.V(4).Infof("Extracting download zip to %q", targetDir) - zipReader, err := zip.NewReader(read, size) - if err != nil { - return err - } - - for _, f := range zipReader.File { - if err := suspiciousPath(f.Name); err != nil { - return err - } - - path := filepath.Join(targetDir, filepath.FromSlash(f.Name)) - if f.FileInfo().IsDir() { - if err := os.MkdirAll(path, f.Mode()); err != nil { - return errors.Wrap(err, "can't create directory tree") - } - continue - } - - src, err := f.Open() - if err != nil { - return errors.Wrap(err, "could not open inflating zip file") - } - - dst, err := os.OpenFile(path, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, f.Mode()) - if err != nil { - src.Close() - return errors.Wrap(err, "can't create file in zip destination dir") - } - close := func() { - src.Close() - dst.Close() - } - - if _, err := io.Copy(dst, src); err != nil { - close() - return errors.Wrap(err, "can't copy content to zip destination file") - } - close() - } - - return nil -} - -// extractTARGZ extracts a gzipped tar file into the target directory. -func extractTARGZ(targetDir string, at io.ReaderAt, size int64) error { - glog.V(4).Infof("tar: extracting to %q", targetDir) - in := io.NewSectionReader(at, 0, size) - - gzr, err := gzip.NewReader(in) - if err != nil { - return errors.Wrap(err, "failed to create gzip reader") - } - defer gzr.Close() - - tr := tar.NewReader(gzr) - for { - hdr, err := tr.Next() - if err == io.EOF { - break - } - if err != nil { - return errors.Wrap(err, "tar extraction error") - } - glog.V(4).Infof("tar: processing %q (type=%d, mode=%s)", hdr.Name, hdr.Typeflag, os.FileMode(hdr.Mode)) - // see https://golang.org/cl/78355 for handling pax_global_header - if hdr.Name == "pax_global_header" { - glog.V(4).Infof("tar: skipping pax_global_header file") - continue - } - - if err := suspiciousPath(hdr.Name); err != nil { - return err - } - - path := filepath.Join(targetDir, filepath.FromSlash(hdr.Name)) - switch hdr.Typeflag { - case tar.TypeDir: - if err := os.MkdirAll(path, os.FileMode(hdr.Mode)); err != nil { - return errors.Wrap(err, "failed to create directory from tar") - } - case tar.TypeReg: - dir := filepath.Dir(path) - glog.V(4).Infof("tar: ensuring parent dirs exist for regular file, dir=%s", dir) - if err := os.MkdirAll(dir, 0755); err != nil { - return errors.Wrap(err, "failed to create directory for tar") - } - f, err := os.OpenFile(path, os.O_CREATE|os.O_WRONLY, os.FileMode(hdr.Mode)) - if err != nil { - return errors.Wrapf(err, "failed to create file %q", path) - } - close := func() { f.Close() } - if _, err := io.Copy(f, tr); err != nil { - close() - return errors.Wrapf(err, "failed to copy %q from tar into file", hdr.Name) - } - close() - default: - return errors.Errorf("unable to handle file type %d for %q in tar", hdr.Typeflag, hdr.Name) - } - glog.V(4).Infof("tar: processed %q", hdr.Name) - } - glog.V(4).Infof("tar extraction to %s complete", targetDir) - return nil + return data, verifier.Verify() } func suspiciousPath(path string) error { @@ -170,44 +60,38 @@ func suspiciousPath(path string) error { return nil } -func detectMIMEType(at io.ReaderAt) (string, error) { - buf := make([]byte, 512) - n, err := at.ReadAt(buf, 0) - if err != nil && err != io.EOF { - return "", errors.Wrap(err, "failed to read first 512 bytes") - } - if n < 512 { - glog.V(5).Infof("Did only read %d of 512 bytes to determine the file type", n) - } +func isSuspiciousArchive(path string) error { + return archiver.Walk(path, func(f archiver.File) error { + switch h := f.Header.(type) { + case *tar.Header: + return suspiciousPath(h.Name) + case zip.FileHeader: + return suspiciousPath(h.Name) + default: + return errors.Errorf("Unknow header type: %T", h) + } + }) +} +func detectMIMEType(data []byte) string { + n := 512 + if l := len(data); l < n { + n = l + } // Cut off mime extra info beginning with ';' i.e: // "text/plain; charset=utf-8" should result in "text/plain". - return strings.Split(http.DetectContentType(buf[:n]), ";")[0], nil -} - -type extractor func(targetDir string, read io.ReaderAt, size int64) error - -var defaultExtractors = map[string]extractor{ - "application/zip": extractZIP, - "application/x-gzip": extractTARGZ, + return strings.Split(http.DetectContentType(data[:n]), ";")[0] } -func extractArchive(dst string, at io.ReaderAt, size int64) error { - // TODO(ahmetb) This package is not architected well, this method should not - // be receiving this many args. Primary problem is at GetInsecure and - // GetWithSha256 methods that embed extraction in them, which is orthogonal. - - t, err := detectMIMEType(at) - if err != nil { - return errors.Wrap(err, "failed to determine content type") +func extensionFromMIME(mime string) (string, error) { + switch mime { + case "application/zip": + return "zip", nil + case "application/x-gzip": + return "tar.gz", nil + default: + return "", errors.Errorf("unknown mime type to extract: %q", mime) } - glog.V(4).Infof("detected %q file type", t) - exf, ok := defaultExtractors[t] - if !ok { - return errors.Errorf("mime type %q for downloaded file is not a supported archive format", t) - } - return errors.Wrap(exf(dst, at, size), "failed to extract file") - } // Downloader is responsible for fetching, verifying and extracting a binary. @@ -227,9 +111,28 @@ func NewDownloader(v Verifier, f Fetcher) Downloader { // Get pulls the uri and verifies it. On success, the download gets extracted // into dst. func (d Downloader) Get(uri, dst string) error { - body, size, err := download(uri, d.verifier, d.fetcher) + data, err := download(uri, d.verifier, d.fetcher) + if err != nil { + return err + } + extension, err := extensionFromMIME(detectMIMEType(data)) if err != nil { return err } - return extractArchive(dst, body, size) + + f, err := ioutil.TempFile("", "plugin.*."+extension) + if err != nil { + return errors.Wrap(err, "failed to create temp file to write") + } + defer os.Remove(f.Name()) + if n, err := f.Write(data); err != nil { + return errors.Wrap(err, "failed to write temp download file") + } else if n != len(data) { + return errors.Errorf("failed to write whole download archive") + } + + if err := isSuspiciousArchive(f.Name()); err != nil { + return err + } + return archiver.Unarchive(f.Name(), dst) } diff --git a/pkg/download/downloader_test.go b/pkg/download/downloader_test.go index 0469a2c3..0981ed22 100644 --- a/pkg/download/downloader_test.go +++ b/pkg/download/downloader_test.go @@ -27,6 +27,7 @@ import ( "strings" "testing" + "github.com/mholt/archiver" "github.com/pkg/errors" "sigs.k8s.io/krew/pkg/testutil" @@ -66,13 +67,7 @@ func Test_extractZIP(t *testing.T) { tmpDir, cleanup := testutil.NewTempDir(t) defer cleanup() - zipReader, err := os.Open(zipSrc) - if err != nil { - t.Fatal(err) - } - defer zipReader.Close() - stat, _ := zipReader.Stat() - if err := extractZIP(tmpDir.Root(), zipReader, stat.Size()); err != nil { + if err := archiver.Unarchive(zipSrc, tmpDir.Root()); err != nil { t.Fatalf("extractZIP(%s) error = %v", tt.in, err) } @@ -113,17 +108,7 @@ func Test_extractTARGZ(t *testing.T) { tmpDir, cleanup := testutil.NewTempDir(t) defer cleanup() - tf, err := os.Open(tarSrc) - if err != nil { - t.Fatalf("failed to open %q. error=%v", tt.in, err) - } - defer tf.Close() - st, err := tf.Stat() - if err != nil { - t.Fatal(err) - return - } - if err := extractTARGZ(tmpDir.Root(), tf, st.Size()); err != nil { + if err := archiver.Unarchive(tarSrc, tmpDir.Root()); err != nil { t.Fatalf("failed to extract %q. error=%v", tt.in, err) } @@ -210,11 +195,10 @@ func Test_download(t *testing.T) { fetcher Fetcher } tests := []struct { - name string - args args - wantReader io.ReaderAt - wantSize int64 - wantErr bool + name string + args args + wantData []byte + wantErr bool }{ { name: "successful fetch", @@ -223,9 +207,8 @@ func Test_download(t *testing.T) { verifier: newTrueVerifier(), fetcher: NewFileFetcher(filePath), }, - wantReader: bytes.NewReader(downloadOriginal), - wantSize: int64(len(downloadOriginal)), - wantErr: false, + wantData: downloadOriginal, + wantErr: false, }, { name: "wrong data fetch", @@ -234,14 +217,13 @@ func Test_download(t *testing.T) { verifier: newFalseVerifier(), fetcher: NewFileFetcher(filePath), }, - wantReader: bytes.NewReader(downloadOriginal), - wantSize: int64(len(downloadOriginal)), - wantErr: true, + wantData: downloadOriginal, + wantErr: true, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - reader, size, err := download(tt.args.url, tt.args.verifier, tt.args.fetcher) + downloadedData, err := download(tt.args.url, tt.args.verifier, tt.args.fetcher) if (err != nil) != tt.wantErr { t.Errorf("download() error = %v, wantErr %v", err, tt.wantErr) return @@ -249,22 +231,10 @@ func Test_download(t *testing.T) { if tt.wantErr { return } - downloadedData, err := ioutil.ReadAll(io.NewSectionReader(reader, 0, size)) - if err != nil { - t.Errorf("failed to read download data: %v", err) - return - } - wantData, err := ioutil.ReadAll(io.NewSectionReader(tt.wantReader, 0, tt.wantSize)) - if err != nil { - t.Errorf("failed to read download data: %v", err) - return - } - if !bytes.Equal(downloadedData, wantData) { - t.Errorf("download() reader = %v, wantReader %v", reader, tt.wantReader) - } - if size != tt.wantSize { - t.Errorf("download() size = %v, wantReader %v", size, tt.wantSize) + if !bytes.Equal(downloadedData, tt.wantData) { + t.Errorf("download() data = %v, wantData %v", + string(downloadedData), string(tt.wantData)) } }) } @@ -301,24 +271,21 @@ func Test_detectMIMEType(t *testing.T) { args: args{ file: filepath.Join(testdataPath(), "test-with-directory.zip"), }, - want: "application/zip", - wantErr: false, + want: "application/zip", }, { name: "type tar.gz", args: args{ file: filepath.Join(testdataPath(), "test-with-nesting-with-directory-entries.tar.gz"), }, - want: "application/x-gzip", - wantErr: false, + want: "application/x-gzip", }, { name: "type bash-utf8", args: args{ file: filepath.Join(testdataPath(), "bash-utf8-file"), }, - want: "text/plain", - wantErr: false, + want: "text/plain", }, { @@ -326,71 +293,56 @@ func Test_detectMIMEType(t *testing.T) { args: args{ file: filepath.Join(testdataPath(), "bash-ascii-file"), }, - want: "text/plain", - wantErr: false, + want: "text/plain", }, { name: "type null", args: args{ file: filepath.Join(testdataPath(), "null-file"), }, - want: "text/plain", - wantErr: false, + want: "text/plain", }, { name: "512 zero bytes", args: args{ content: make([]byte, 512), }, - want: "application/octet-stream", - wantErr: false, + want: "application/octet-stream", }, { name: "1 zero bytes", args: args{ content: make([]byte, 1), }, - want: "application/octet-stream", - wantErr: false, + want: "application/octet-stream", }, { name: "0 zero bytes", args: args{ content: []byte{}, }, - want: "text/plain", - wantErr: false, + want: "text/plain", }, { name: "html", args: args{ content: []byte(""), }, - want: "text/html", - wantErr: false, + want: "text/html", }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - var at io.ReaderAt + data := tt.args.content if tt.args.file != "" { - fd, err := os.Open(tt.args.file) - if err != nil { - t.Errorf("failed to read file %s, err: %v", tt.args.file, err) - return + var err error + if data, err = ioutil.ReadFile(tt.args.file); err != nil { + t.Errorf("failed to read file %q, err: %v", tt.args.file, err) } - defer fd.Close() - at = fd - } else { - at = bytes.NewReader(tt.args.content) } - got, err := detectMIMEType(at) - if (err != nil) != tt.wantErr { - t.Errorf("detectMIMEType() error = %v, wantErr %v", err, tt.wantErr) - return - } + got := detectMIMEType(data) if got != tt.want { t.Errorf("detectMIMEType() = %v, want %v", got, tt.want) } @@ -398,64 +350,6 @@ func Test_detectMIMEType(t *testing.T) { } } -func Test_extractArchive(t *testing.T) { - oldextractors := defaultExtractors - defer func() { - defaultExtractors = oldextractors - }() - defaultExtractors = map[string]extractor{ - "application/octet-stream": func(targetDir string, read io.ReaderAt, size int64) error { return nil }, - "text/plain": func(targetDir string, read io.ReaderAt, size int64) error { return errors.New("fail test") }, - } - type args struct { - filename string - dst string - file string - } - tests := []struct { - name string - args args - wantErr bool - }{ - { - name: "test fail extraction", - args: args{ - filename: "", - dst: "", - file: filepath.Join(testdataPath(), "null-file"), - }, - wantErr: true, - }, - { - name: "test type not found extraction", - args: args{ - filename: "", - dst: "", - file: filepath.Join(testdataPath(), "test-with-nesting-with-directory-entries.tar.gz"), - }, - wantErr: true, - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - fd, err := os.Open(tt.args.file) - if err != nil { - t.Errorf("failed to read file %s, err: %v", tt.args.file, err) - return - } - st, err := fd.Stat() - if err != nil { - t.Errorf("failed to stat file %s, err: %v", tt.args.file, err) - return - } - - if err := extractArchive(tt.args.dst, fd, st.Size()); (err != nil) != tt.wantErr { - t.Errorf("extractArchive() error = %v, wantErr %v", err, tt.wantErr) - } - }) - } -} - func Test_suspiciousPath(t *testing.T) { tests := []struct { path string @@ -533,47 +427,40 @@ func Test_extractMaliciousArchive(t *testing.T) { }, } - for _, tt := range tests { - t.Run("tar.gz "+tt.name, func(t *testing.T) { - tmpDir, cleanup := testutil.NewTempDir(t) - defer cleanup() - - // do not use filepath.Join here, because it calls filepath.Clean on the result - reader, err := tarGZArchiveForTesting(map[string]string{tt.path: testContent}) - if err != nil { - t.Fatal(err) - } - - err = extractTARGZ(tmpDir.Root(), reader, reader.Size()) - if err == nil { - t.Errorf("Expected extractTARGZ to fail") - } else if !strings.HasPrefix(err.Error(), "refusing to unpack archive") { - t.Errorf("Found the wrong error: %s", err) - } - }) + modes := map[string]archiveMaker{ + "tar.gz": tarGZArchiveForTesting, + "zip": zipArchiveForTesting, } for _, tt := range tests { - t.Run("zip "+tt.name, func(t *testing.T) { - tmpDir, cleanup := testutil.NewTempDir(t) - defer cleanup() + for mode, maker := range modes { + t.Run(mode+" "+tt.name, func(t *testing.T) { + tmpDir, cleanup := testutil.NewTempDir(t) + defer cleanup() - // do not use filepath.Join here, because it calls filepath.Clean on the result - reader, err := zipArchiveReaderForTesting(map[string]string{tt.path: testContent}) - if err != nil { - t.Fatal(err) - } + // do not use filepath.Join here, because it calls filepath.Clean on the result + data, err := maker(map[string]string{tt.path: testContent}) + if err != nil { + t.Fatal(err) + } + archive := filepath.Join(tmpDir.Root(), "plugin."+mode) + if err := ioutil.WriteFile(archive, data, 0664); err != nil { + t.Errorf("Writing fails: %s", err) + } - err = extractZIP(tmpDir.Root(), reader, reader.Size()) - if err == nil { - t.Errorf("Expected extractZIP to fail") - } else if !strings.HasPrefix(err.Error(), "refusing to unpack archive") { - t.Errorf("Found the wrong error: %s", err) - } - }) + err = isSuspiciousArchive(archive) + if err == nil { + t.Errorf("Expected %s unarchive to fail", mode) + } else if !strings.Contains(err.Error(), "refusing to unpack archive") { + t.Errorf("Found the wrong error: %s", err) + } + }) + } } } +type archiveMaker func(map[string]string) ([]byte, error) + // tarGZArchiveForTesting creates an in-memory zip archive with entries from // the files map, where keys are the paths and values are the contents. // For example, to create an empty file `a` and another file `b/c`: @@ -581,7 +468,7 @@ func Test_extractMaliciousArchive(t *testing.T) { // "a": "", // "b/c": "nested content", // }) -func tarGZArchiveForTesting(files map[string]string) (*bytes.Reader, error) { +func tarGZArchiveForTesting(files map[string]string) ([]byte, error) { archiveBuffer := &bytes.Buffer{} gzArchiveBuffer := gzip.NewWriter(archiveBuffer) tw := tar.NewWriter(gzArchiveBuffer) @@ -597,7 +484,6 @@ func tarGZArchiveForTesting(files map[string]string) (*bytes.Reader, error) { if _, err := tw.Write([]byte(content)); err != nil { return nil, err } - } if err := tw.Close(); err != nil { return nil, err @@ -605,16 +491,16 @@ func tarGZArchiveForTesting(files map[string]string) (*bytes.Reader, error) { if err := gzArchiveBuffer.Close(); err != nil { return nil, err } - return bytes.NewReader(archiveBuffer.Bytes()), nil + return archiveBuffer.Bytes(), nil } -// zipArchiveReaderForTesting creates an in-memory zip archive with entries from +// zipArchiveForTesting creates an in-memory zip archive with entries from // the files map, where keys are the paths and values are the contents. Note that // entries with empty content just create a directory. The zip spec requires that // parent directories are explicitly listed in the archive, so this must be done // for nested entries. For example, to create a file at `a/b/c`, you must pass: // map[string]string{"a": "", "a/b": "", "a/b/c": "nested content"} -func zipArchiveReaderForTesting(files map[string]string) (*bytes.Reader, error) { +func zipArchiveForTesting(files map[string]string) ([]byte, error) { archiveBuffer := &bytes.Buffer{} zw := zip.NewWriter(archiveBuffer) for path, content := range files { @@ -632,5 +518,5 @@ func zipArchiveReaderForTesting(files map[string]string) (*bytes.Reader, error) if err := zw.Close(); err != nil { return nil, err } - return bytes.NewReader(archiveBuffer.Bytes()), nil + return archiveBuffer.Bytes(), nil }