forked from pocketbase/pocketbase
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
added helper archive package to create and extract zips
- Loading branch information
1 parent
dfabfa7
commit 90abe16
Showing
4 changed files
with
283 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,70 @@ | ||
package archive | ||
|
||
import ( | ||
"archive/zip" | ||
"io" | ||
"io/fs" | ||
"os" | ||
) | ||
|
||
// Create creates a new zip archive from src dir content and saves it in dest path. | ||
func Create(src, dest string) error { | ||
zf, err := os.Create(dest) | ||
if err != nil { | ||
return err | ||
} | ||
defer zf.Close() | ||
|
||
zw := zip.NewWriter(zf) | ||
defer zw.Close() | ||
|
||
if err := zipAddFS(zw, os.DirFS(src)); err != nil { | ||
// try to cleanup the created zip file | ||
os.Remove(dest) | ||
|
||
return err | ||
} | ||
|
||
return nil | ||
} | ||
|
||
// note remove after similar method is added in the std lib (https://github.com/golang/go/issues/54898) | ||
func zipAddFS(w *zip.Writer, fsys fs.FS) error { | ||
return fs.WalkDir(fsys, ".", func(name string, d fs.DirEntry, err error) error { | ||
if err != nil { | ||
return err | ||
} | ||
|
||
if d.IsDir() { | ||
return nil | ||
} | ||
|
||
info, err := d.Info() | ||
if err != nil { | ||
return err | ||
} | ||
|
||
h, err := zip.FileInfoHeader(info) | ||
if err != nil { | ||
return err | ||
} | ||
|
||
h.Name = name | ||
h.Method = zip.Deflate | ||
|
||
fw, err := w.CreateHeader(h) | ||
if err != nil { | ||
return err | ||
} | ||
|
||
f, err := fsys.Open(name) | ||
if err != nil { | ||
return err | ||
} | ||
defer f.Close() | ||
|
||
_, err = io.Copy(fw, f) | ||
|
||
return err | ||
}) | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,84 @@ | ||
package archive_test | ||
|
||
import ( | ||
"os" | ||
"path/filepath" | ||
"testing" | ||
|
||
"github.com/pocketbase/pocketbase/tools/archive" | ||
) | ||
|
||
func TestCreateFailure(t *testing.T) { | ||
testDir := createTestDir(t) | ||
defer os.RemoveAll(testDir) | ||
|
||
zipPath := filepath.Join(os.TempDir(), "pb_test.zip") | ||
defer os.RemoveAll(zipPath) | ||
|
||
missingDir := filepath.Join(os.TempDir(), "missing") | ||
|
||
if err := archive.Create(missingDir, zipPath); err == nil { | ||
t.Fatal("Expected to fail due to missing directory or file") | ||
} | ||
|
||
if _, err := os.Stat(zipPath); err == nil { | ||
t.Fatalf("Expected the zip file not to be created") | ||
} | ||
} | ||
|
||
func TestCreateSuccess(t *testing.T) { | ||
testDir := createTestDir(t) | ||
defer os.RemoveAll(testDir) | ||
|
||
zipName := "pb_test.zip" | ||
zipPath := filepath.Join(os.TempDir(), zipName) | ||
defer os.RemoveAll(zipPath) | ||
|
||
// zip testDir content | ||
if err := archive.Create(testDir, zipPath); err != nil { | ||
t.Fatalf("Failed to create archive: %v", err) | ||
} | ||
|
||
info, err := os.Stat(zipPath) | ||
if err != nil { | ||
t.Fatalf("Failed to retrieve the generated zip file: %v", err) | ||
} | ||
|
||
if name := info.Name(); name != zipName { | ||
t.Fatalf("Expected zip with name %q, got %q", zipName, name) | ||
} | ||
|
||
expectedSize := int64(300) | ||
if size := info.Size(); size != expectedSize { | ||
t.Fatalf("Expected zip with size %d, got %d", expectedSize, size) | ||
} | ||
} | ||
|
||
// --- | ||
|
||
// note: make sure to call os.RemoveAll(dir) after you are done | ||
// working with the created test dir. | ||
func createTestDir(t *testing.T) string { | ||
dir, err := os.MkdirTemp(os.TempDir(), "pb_zip_test") | ||
if err != nil { | ||
t.Fatal(err) | ||
} | ||
|
||
if err := os.MkdirAll(filepath.Join(dir, "a/b/c"), os.ModePerm); err != nil { | ||
t.Fatal(err) | ||
} | ||
|
||
sub1, err := os.OpenFile(filepath.Join(dir, "a/sub1.txt"), os.O_WRONLY|os.O_CREATE, 0644) | ||
if err != nil { | ||
t.Fatal(err) | ||
} | ||
sub1.Close() | ||
|
||
sub2, err := os.OpenFile(filepath.Join(dir, "a/b/c/sub2.txt"), os.O_WRONLY|os.O_CREATE, 0644) | ||
if err != nil { | ||
t.Fatal(err) | ||
} | ||
sub2.Close() | ||
|
||
return dir | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,72 @@ | ||
package archive | ||
|
||
import ( | ||
"archive/zip" | ||
"fmt" | ||
"io" | ||
"os" | ||
"path/filepath" | ||
"strings" | ||
) | ||
|
||
// Extract extracts the zip archive at src to dest. | ||
func Extract(src, dest string) error { | ||
zr, err := zip.OpenReader(src) | ||
if err != nil { | ||
return err | ||
} | ||
defer zr.Close() | ||
|
||
// normalize dest path to check later for Zip Slip | ||
dest = filepath.Clean(dest) + string(os.PathSeparator) | ||
|
||
for _, f := range zr.File { | ||
err := extractFile(f, dest) | ||
if err != nil { | ||
return err | ||
} | ||
} | ||
|
||
return nil | ||
} | ||
|
||
// extractFile extracts the provided zipFile into "basePath/zipFileName" path, | ||
// creating all the necessary path directories. | ||
func extractFile(zipFile *zip.File, basePath string) error { | ||
path := filepath.Join(basePath, zipFile.Name) | ||
|
||
// check for Zip Slip | ||
if !strings.HasPrefix(path, basePath) { | ||
return fmt.Errorf("invalid file path: %s", path) | ||
} | ||
|
||
r, err := zipFile.Open() | ||
if err != nil { | ||
return err | ||
} | ||
defer r.Close() | ||
|
||
if zipFile.FileInfo().IsDir() { | ||
if err := os.MkdirAll(path, os.ModePerm); err != nil { | ||
return err | ||
} | ||
} else { | ||
// ensure that the file path directories are created | ||
if err := os.MkdirAll(filepath.Dir(path), os.ModePerm); err != nil { | ||
return err | ||
} | ||
|
||
f, err := os.OpenFile(path, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, zipFile.Mode()) | ||
if err != nil { | ||
return err | ||
} | ||
defer f.Close() | ||
|
||
_, err = io.Copy(f, r) | ||
if err != nil { | ||
return err | ||
} | ||
} | ||
|
||
return nil | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,57 @@ | ||
package archive_test | ||
|
||
import ( | ||
"os" | ||
"path/filepath" | ||
"testing" | ||
|
||
"github.com/pocketbase/pocketbase/tools/archive" | ||
) | ||
|
||
func TestExtractFailure(t *testing.T) { | ||
testDir := createTestDir(t) | ||
defer os.RemoveAll(testDir) | ||
|
||
missingZipPath := filepath.Join(os.TempDir(), "pb_missing_test.zip") | ||
extractPath := filepath.Join(os.TempDir(), "pb_zip_extract") | ||
defer os.RemoveAll(extractPath) | ||
|
||
if err := archive.Extract(missingZipPath, extractPath); err == nil { | ||
t.Fatal("Expected Extract to fail due to missing zipPath") | ||
} | ||
|
||
if _, err := os.Stat(extractPath); err == nil { | ||
t.Fatalf("Expected %q to not be created", extractPath) | ||
} | ||
} | ||
|
||
func TestExtractSuccess(t *testing.T) { | ||
testDir := createTestDir(t) | ||
defer os.RemoveAll(testDir) | ||
|
||
zipPath := filepath.Join(os.TempDir(), "pb_test.zip") | ||
defer os.RemoveAll(zipPath) | ||
|
||
extractPath := filepath.Join(os.TempDir(), "pb_zip_extract") | ||
defer os.RemoveAll(extractPath) | ||
|
||
// zip testDir content | ||
if err := archive.Create(testDir, zipPath); err != nil { | ||
t.Fatalf("Failed to create archive: %v", err) | ||
} | ||
|
||
if err := archive.Extract(zipPath, extractPath); err != nil { | ||
t.Fatalf("Failed to extract %q in %q", zipPath, extractPath) | ||
} | ||
|
||
pathsToCheck := []string{ | ||
filepath.Join(extractPath, "a/sub1.txt"), | ||
filepath.Join(extractPath, "a/b/c/sub2.txt"), | ||
} | ||
|
||
for _, p := range pathsToCheck { | ||
if _, err := os.Stat(p); err != nil { | ||
t.Fatalf("Failed to retrieve extracted file %q: %v", p, err) | ||
} | ||
} | ||
} |