Skip to content

Commit

Permalink
feat: support referencing csv file in fga store file
Browse files Browse the repository at this point in the history
  • Loading branch information
rhamzeh committed Feb 6, 2024
1 parent 02fd2e4 commit 489a2f2
Show file tree
Hide file tree
Showing 8 changed files with 224 additions and 239 deletions.
2 changes: 1 addition & 1 deletion cmd/model/write_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,8 +39,8 @@ func TestWriteModelFail(t *testing.T) {
mockFgaClient.EXPECT().WriteAuthorizationModel(context.Background()).Return(mockRequest)

model := authorizationmodel.AuthzModel{}
err = model.ReadFromJSONString(modelJSONTxt)

err = model.ReadFromJSONString(modelJSONTxt)
if err != nil {
return
}
Expand Down
4 changes: 2 additions & 2 deletions cmd/store/import.go
Original file line number Diff line number Diff line change
Expand Up @@ -66,17 +66,17 @@ func importStore(
return fmt.Errorf("failed to write model due to %w", err)
}
}
fgaClient, err = clientConfig.GetFgaClient() //nolint:wsl

fgaClient, err = clientConfig.GetFgaClient()
if err != nil {
return fmt.Errorf("failed to initialize FGA Client due to %w", err)
}

writeRequest := client.ClientWriteRequest{
Writes: storeData.Tuples,
}
_, err = tuple.ImportTuples(fgaClient, writeRequest, maxTuplesPerWrite, maxParallelRequests)

_, err = tuple.ImportTuples(fgaClient, writeRequest, maxTuplesPerWrite, maxParallelRequests)
if err != nil {
return err //nolint:wrapcheck
}
Expand Down
195 changes: 3 additions & 192 deletions cmd/tuple/write.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,26 +17,16 @@ limitations under the License.
package tuple

import (
"bytes"
"context"
"encoding/csv"
"errors"
"fmt"
"io"
"os"
"path"
"strings"

"github.com/openfga/cli/internal/clierrors"

openfga "github.com/openfga/go-sdk"
"github.com/openfga/go-sdk/client"
"github.com/spf13/cobra"
flag "github.com/spf13/pflag"
"gopkg.in/yaml.v3"

"github.com/openfga/cli/internal/cmdutils"
"github.com/openfga/cli/internal/output"
"github.com/openfga/cli/internal/tuplefile"
)

const writeCommandArgumentsCount = 3
Expand Down Expand Up @@ -137,9 +127,9 @@ func writeTuplesFromFile(flags *flag.FlagSet, fgaClient *client.OpenFgaClient) e
return fmt.Errorf("failed to parse parallel requests: %w", err)
}

tuples, err := parseTuplesFileData(fileName)
tuples, err := tuplefile.ReadTupleFile(fileName)
if err != nil {
return err
return err //nolint:wrapcheck
}

writeRequest := client.ClientWriteRequest{
Expand All @@ -154,185 +144,6 @@ func writeTuplesFromFile(flags *flag.FlagSet, fgaClient *client.OpenFgaClient) e
return output.Display(response) //nolint:wrapcheck
}

func parseTuplesFileData(fileName string) ([]client.ClientTupleKey, error) {
data, err := os.ReadFile(fileName)
if err != nil {
return nil, fmt.Errorf("failed to read file %q: %w", fileName, err)
}

var tuples []client.ClientTupleKey

switch path.Ext(fileName) {
case ".json", ".yaml", ".yml":
err = yaml.Unmarshal(data, &tuples)
case ".csv":
err = parseTuplesFromCSV(data, &tuples)
default:
err = fmt.Errorf("unsupported file format %q", path.Ext(fileName)) //nolint:goerr113
}

if err != nil {
return nil, fmt.Errorf("failed to parse input tuples: %w", err)
}

return tuples, nil
}

func parseTuplesFromCSV(data []byte, tuples *[]client.ClientTupleKey) error {
reader := csv.NewReader(bytes.NewReader(data))

columns, err := readHeaders(reader)
if err != nil {
return err
}

for index := 0; true; index++ {
tuple, err := reader.Read()
if err != nil {
if errors.Is(err, io.EOF) {
break
}

return fmt.Errorf("failed to read tuple from csv file: %w", err)
}

tupleUserKey := tuple[columns.UserType] + ":" + tuple[columns.UserID]
if columns.UserRelation != -1 && tuple[columns.UserRelation] != "" {
tupleUserKey += "#" + tuple[columns.UserRelation]
}

condition, err := parseConditionColumnsForRow(columns, tuple, index)
if err != nil {
return err
}

tupleKey := client.ClientTupleKey{
User: tupleUserKey,
Relation: tuple[columns.Relation],
Object: tuple[columns.ObjectType] + ":" + tuple[columns.ObjectID],
Condition: condition,
}

*tuples = append(*tuples, tupleKey)
}

return nil
}

func parseConditionColumnsForRow(columns *csvColumns, tuple []string, index int) (*openfga.RelationshipCondition, error) {
var condition *openfga.RelationshipCondition

if columns.ConditionName != -1 && tuple[columns.ConditionName] != "" {
conditionContext := &(map[string]interface{}{})

if columns.ConditionContext != -1 {
var err error

conditionContext, err = cmdutils.ParseQueryContextInner(tuple[columns.ConditionContext])
if err != nil {
return nil, fmt.Errorf("failed to read condition context on line %d: %w", index, err)
}
}

condition = &openfga.RelationshipCondition{
Name: tuple[columns.ConditionName],
Context: conditionContext,
}
}

return condition, nil
}

type csvColumns struct {
UserType int
UserID int
UserRelation int
Relation int
ObjectType int
ObjectID int
ConditionName int
ConditionContext int
}

func (columns *csvColumns) setHeaderIndex(headerName string, index int) error {
switch headerName {
case "user_type":
columns.UserType = index
case "user_id":
columns.UserID = index
case "user_relation":
columns.UserRelation = index
case "relation":
columns.Relation = index
case "object_type":
columns.ObjectType = index
case "object_id":
columns.ObjectID = index
case "condition_name":
columns.ConditionName = index
case "condition_context":
columns.ConditionContext = index
default:
return fmt.Errorf("invalid header %q, valid headers are user_type,user_id,user_relation,relation,object_type,object_id,condition_name,condition_context", headerName) //nolint:goerr113
}

return nil
}

func (columns *csvColumns) validate() error {
if columns.UserType == -1 {
return clierrors.MissingRequiredCsvHeaderError("user_type") //nolint:wrapcheck
}

if columns.UserID == -1 {
return clierrors.MissingRequiredCsvHeaderError("user_id") //nolint:wrapcheck
}

if columns.Relation == -1 {
return clierrors.MissingRequiredCsvHeaderError("relation") //nolint:wrapcheck
}

if columns.ObjectType == -1 {
return clierrors.MissingRequiredCsvHeaderError("object_type") //nolint:wrapcheck
}

if columns.ObjectID == -1 {
return clierrors.MissingRequiredCsvHeaderError("object_id") //nolint:wrapcheck
}

if columns.ConditionContext != -1 && columns.ConditionName == -1 {
return errors.New("missing \"condition_name\" header which is required when \"condition_context\" is present") //nolint:goerr113
}

return nil
}

func readHeaders(reader *csv.Reader) (*csvColumns, error) {
headers, err := reader.Read()
if err != nil {
return nil, fmt.Errorf("failed to read csv headers: %w", err)
}

columns := &csvColumns{
UserType: -1,
UserID: -1,
UserRelation: -1,
Relation: -1,
ObjectType: -1,
ObjectID: -1,
ConditionName: -1,
ConditionContext: -1,
}
for index, header := range headers {
err = columns.setHeaderIndex(strings.TrimSpace(header), index)
if err != nil {
return nil, err
}
}

return columns, columns.validate()
}

func init() {
writeCmd.Flags().String("model-id", "", "Model ID")
writeCmd.Flags().String("file", "", "Tuples file")
Expand Down
4 changes: 3 additions & 1 deletion cmd/tuple/write_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@ import (
"github.com/openfga/go-sdk/client"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"

"github.com/openfga/cli/internal/tuplefile"
)

func TestParseTuplesFileData(t *testing.T) { //nolint:funlen
Expand Down Expand Up @@ -209,7 +211,7 @@ func TestParseTuplesFileData(t *testing.T) { //nolint:funlen
t.Run(test.name, func(t *testing.T) {
t.Parallel()

actualTuples, err := parseTuplesFileData(test.file)
actualTuples, err := tuplefile.ReadTupleFile(test.file)

if test.expectedError != "" {
require.EqualError(t, err, test.expectedError)
Expand Down
2 changes: 1 addition & 1 deletion internal/storetest/read-from-input.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,8 @@ func ReadFromFile(fileName string, basePath string) (authorizationmodel.ModelFor

decoder := yaml.NewDecoder(testFile)
decoder.KnownFields(true)
err = decoder.Decode(&storeData)

err = decoder.Decode(&storeData)
if err != nil {
return format, nil, fmt.Errorf("failed to unmarshal file %s due to %w", fileName, err)
}
Expand Down
45 changes: 3 additions & 42 deletions internal/storetest/storedata.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,18 +18,14 @@ limitations under the License.
package storetest

import (
"encoding/json"
"errors"
"fmt"
"io"
"os"
"path"

"github.com/openfga/go-sdk/client"

"gopkg.in/yaml.v3"

"github.com/openfga/cli/internal/authorizationmodel"
"github.com/openfga/cli/internal/tuplefile"
)

type ModelTestCheck struct {
Expand Down Expand Up @@ -96,7 +92,7 @@ func (storeData *StoreData) LoadTuples(basePath string) error {
var errs error

if storeData.TupleFile != "" {
tuples, err := readTupleFile(path.Join(basePath, storeData.TupleFile))
tuples, err := tuplefile.ReadTupleFile(path.Join(basePath, storeData.TupleFile))
if err != nil {
errs = fmt.Errorf("failed to process global tuple %s file due to %w", storeData.TupleFile, err)
} else {
Expand All @@ -110,7 +106,7 @@ func (storeData *StoreData) LoadTuples(basePath string) error {
continue
}

tuples, err := readTupleFile(path.Join(basePath, test.TupleFile))
tuples, err := tuplefile.ReadTupleFile(path.Join(basePath, test.TupleFile))
if err != nil {
errs = errors.Join(
errs,
Expand All @@ -127,38 +123,3 @@ func (storeData *StoreData) LoadTuples(basePath string) error {

return nil
}

func readTupleFile(tuplePath string) ([]client.ClientContextualTupleKey, error) {
var tuples []client.ClientContextualTupleKey

tupleFile, err := os.Open(tuplePath)
if err != nil {
return nil, err //nolint:wrapcheck
}
defer tupleFile.Close()

switch path.Ext(tuplePath) {
case ".json":
contents, err := io.ReadAll(tupleFile)
if err != nil {
return nil, err //nolint:wrapcheck
}

err = json.Unmarshal(contents, &tuples)
if err != nil {
return nil, err //nolint:wrapcheck
}
case ".yaml", ".yml":
decoder := yaml.NewDecoder(tupleFile)
decoder.KnownFields(true)

err = decoder.Decode(&tuples)
if err != nil {
return nil, err //nolint:wrapcheck
}
default:
return nil, fmt.Errorf("unsupported file format %s", path.Ext(tuplePath)) //nolint:goerr113
}

return tuples, nil
}
Loading

0 comments on commit 489a2f2

Please sign in to comment.