From fa24b46533e6605d8d4a795c557f968f4c6fbacb Mon Sep 17 00:00:00 2001 From: Anurag Bandyopadhyay Date: Thu, 11 Jun 2026 19:33:25 +0530 Subject: [PATCH] refactor: drop gocsv to stdlib encoding/csv --- cmd/tuple/read.go | 59 ++++++++++++++--- cmd/tuple/read_test.go | 44 +++++++++++- go.mod | 1 - go.sum | 2 - internal/output/marshal.go | 53 ++++++++++++++- internal/output/marshal_test.go | 114 ++++++++++++++++++++++++++++++++ 6 files changed, 256 insertions(+), 17 deletions(-) create mode 100644 internal/output/marshal_test.go diff --git a/cmd/tuple/read.go b/cmd/tuple/read.go index 3f3b8118..9574b58b 100644 --- a/cmd/tuple/read.go +++ b/cmd/tuple/read.go @@ -48,18 +48,57 @@ type readResponse struct { } type readResponseCSVDTO struct { - UserType string `csv:"user_type"` - UserID string `csv:"user_id"` - UserRelation string `csv:"user_relation,omitempty"` - Relation string `csv:"relation"` - ObjectType string `csv:"object_type"` - ObjectID string `csv:"object_id"` - ConditionName string `csv:"condition_name,omitempty"` - ConditionContext string `csv:"condition_context,omitempty"` + UserType string + UserID string + UserRelation string + Relation string + ObjectType string + ObjectID string + ConditionName string + ConditionContext string } -func (r readResponse) toCsvDTO() ([]readResponseCSVDTO, error) { - readResponseDTO := make([]readResponseCSVDTO, 0, len(r.simple)) +type readResponseCSVDTOList []readResponseCSVDTO + +var readResponseCSVHeaders = []string{ + "user_type", + "user_id", + "user_relation", + "relation", + "object_type", + "object_id", + "condition_name", + "condition_context", +} + +func (dto readResponseCSVDTO) MarshalCSV() ([]string, error) { + return []string{ + dto.UserType, + dto.UserID, + dto.UserRelation, + dto.Relation, + dto.ObjectType, + dto.ObjectID, + dto.ConditionName, + dto.ConditionContext, + }, nil +} + +func (l readResponseCSVDTOList) CSVHeaders() []string { + return readResponseCSVHeaders +} + +func (l readResponseCSVDTOList) CSVRecords() []output.CSVMarshaler { + records := make([]output.CSVMarshaler, len(l)) + for i, dto := range l { + records[i] = dto + } + + return records +} + +func (r readResponse) toCsvDTO() (readResponseCSVDTOList, error) { + readResponseDTO := make(readResponseCSVDTOList, 0, len(r.simple)) for _, readRes := range r.simple { // Handle Condition diff --git a/cmd/tuple/read_test.go b/cmd/tuple/read_test.go index 54007c08..a05ff6a8 100644 --- a/cmd/tuple/read_test.go +++ b/cmd/tuple/read_test.go @@ -7,6 +7,7 @@ import ( "time" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" openfga "github.com/openfga/go-sdk" "github.com/openfga/go-sdk/client" @@ -438,7 +439,7 @@ func TestReadResponseCSVDTOParser(t *testing.T) { testCases := []struct { readRes readResponse - expected []readResponseCSVDTO + expected readResponseCSVDTOList }{ { readRes: readResponse{ @@ -460,7 +461,7 @@ func TestReadResponseCSVDTOParser(t *testing.T) { }, }, }, - expected: []readResponseCSVDTO{ + expected: readResponseCSVDTOList{ { UserType: "user", UserID: "anne", @@ -487,6 +488,45 @@ func TestReadResponseCSVDTOParser(t *testing.T) { } } +func TestReadResponseCSVDTOListMarshalCSV(t *testing.T) { + t.Parallel() + + list := readResponseCSVDTOList{ + { + UserType: "user", + UserID: "anne", + Relation: "reader", + ObjectType: "document", + ObjectID: "secret.doc", + ConditionName: "inOfficeIP", + ConditionContext: `{"ip_addr":"10.0.0.1"}`, + }, + { + UserType: "user", + UserID: "john", + Relation: "writer", + ObjectType: "document", + ObjectID: "abc.doc", + }, + } + + assert.Equal(t, readResponseCSVHeaders, list.CSVHeaders()) + + rows := make([][]string, 0, len(list)) + + for _, record := range list.CSVRecords() { + row, err := record.MarshalCSV() + require.NoError(t, err) + + rows = append(rows, row) + } + + assert.Equal(t, [][]string{ + {"user", "anne", "", "reader", "document", "secret.doc", "inOfficeIP", `{"ip_addr":"10.0.0.1"}`}, + {"user", "john", "", "writer", "document", "abc.doc", "", ""}, + }, rows) +} + func toPointer[T any](p T) *T { return &p } diff --git a/go.mod b/go.mod index f48b781e..01fcafdf 100644 --- a/go.mod +++ b/go.mod @@ -5,7 +5,6 @@ go 1.25.7 toolchain go1.26.4 require ( - github.com/gocarina/gocsv v0.0.0-20240520201108-78e41c74b4b1 github.com/hashicorp/go-multierror v1.1.1 github.com/mattn/go-isatty v0.0.22 github.com/muesli/mango-cobra v1.3.0 diff --git a/go.sum b/go.sum index c9c9bf3b..9f6e4a69 100644 --- a/go.sum +++ b/go.sum @@ -74,8 +74,6 @@ github.com/go-sql-driver/mysql v1.10.0/go.mod h1:M+cqaI7+xxXGG9swrdeUIoPG3Y3KCkF github.com/go-stack/stack v1.8.0/go.mod h1:v0f6uXyyMGvRgIKkXu+yp6POWl0qKG85gN/melR3HDY= github.com/go-viper/mapstructure/v2 v2.5.0 h1:vM5IJoUAy3d7zRSVtIwQgBj7BiWtMPfmPEgAXnvj1Ro= github.com/go-viper/mapstructure/v2 v2.5.0/go.mod h1:oJDH3BJKyqBA2TXFhDsKDGDTlndYOZ6rGS0BRZIxGhM= -github.com/gocarina/gocsv v0.0.0-20240520201108-78e41c74b4b1 h1:FWNFq4fM1wPfcK40yHE5UO3RUdSNPaBC+j3PokzA6OQ= -github.com/gocarina/gocsv v0.0.0-20240520201108-78e41c74b4b1/go.mod h1:5YoVOkjYAQumqlV356Hj3xeYh4BdZuLE0/nRkf2NKkI= github.com/gogo/protobuf v1.3.2 h1:Ov1cvc58UF3b5XjBnZv7+opcTcQFZebYjWzi34vdm4Q= github.com/gogo/protobuf v1.3.2/go.mod h1:P1XiOD3dCwIKUDQYPy72D8LYyHL2YPYrpS2s69NZV8Q= github.com/golang-jwt/jwt/v5 v5.3.1 h1:kYf81DTWFe7t+1VvL7eS+jKFVWaUnK9cB1qbwn63YCY= diff --git a/internal/output/marshal.go b/internal/output/marshal.go index 828d0ad6..2055e988 100644 --- a/internal/output/marshal.go +++ b/internal/output/marshal.go @@ -18,11 +18,13 @@ limitations under the License. package output import ( + "bytes" + "encoding/csv" "encoding/json" + "errors" "fmt" "os" - "github.com/gocarina/gocsv" "gopkg.in/yaml.v3" "github.com/mattn/go-isatty" @@ -74,7 +76,7 @@ func (prt *csvPrinter) DisplayColor(data any) error { } func (prt *csvPrinter) DisplayNoColor(data any) error { - b, err := gocsv.MarshalBytes(data) + b, err := marshalCSV(data) if err != nil { return fmt.Errorf("unable to marshal CSV with error: %w", err) } @@ -84,6 +86,53 @@ func (prt *csvPrinter) DisplayNoColor(data any) error { return nil } +var errNotCSVMarshaler = errors.New("type does not implement output.CSVRecordSet") + +// CSVMarshaler is implemented by a type that can render itself as a single CSV record. +type CSVMarshaler interface { + MarshalCSV() ([]string, error) +} + +// CSVRecordSet is implemented by a collection that can render itself as CSV: +// a header row followed by one record per element. +type CSVRecordSet interface { + CSVHeaders() []string + CSVRecords() []CSVMarshaler +} + +func marshalCSV(data any) ([]byte, error) { + recordSet, ok := data.(CSVRecordSet) + if !ok { + return nil, fmt.Errorf("cannot marshal %T to csv: %w", data, errNotCSVMarshaler) + } + + buffer := &bytes.Buffer{} + writer := csv.NewWriter(buffer) + + if err := writer.Write(recordSet.CSVHeaders()); err != nil { + return nil, fmt.Errorf("failed to write csv header: %w", err) + } + + for _, record := range recordSet.CSVRecords() { + row, err := record.MarshalCSV() + if err != nil { + return nil, fmt.Errorf("failed to marshal csv record: %w", err) + } + + if err := writer.Write(row); err != nil { + return nil, fmt.Errorf("failed to write csv record: %w", err) + } + } + + writer.Flush() + + if err := writer.Error(); err != nil { + return nil, fmt.Errorf("failed to flush csv: %w", err) + } + + return buffer.Bytes(), nil +} + func (prt *yamlPrinter) DisplayColor(data any) error { return prt.DisplayNoColor(data) } diff --git a/internal/output/marshal_test.go b/internal/output/marshal_test.go new file mode 100644 index 00000000..2e533939 --- /dev/null +++ b/internal/output/marshal_test.go @@ -0,0 +1,114 @@ +package output + +import ( + "errors" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +type fakeCSVRow struct { + fields []string + err error +} + +func (r fakeCSVRow) MarshalCSV() ([]string, error) { + return r.fields, r.err +} + +type fakeCSVRecordSet struct { + headers []string + records []CSVMarshaler +} + +func (f fakeCSVRecordSet) CSVHeaders() []string { + return f.headers +} + +func (f fakeCSVRecordSet) CSVRecords() []CSVMarshaler { + return f.records +} + +func recordSet(headers []string, rows [][]string) fakeCSVRecordSet { + records := make([]CSVMarshaler, len(rows)) + for i, row := range rows { + records[i] = fakeCSVRow{fields: row} + } + + return fakeCSVRecordSet{headers: headers, records: records} +} + +func TestMarshalCSV(t *testing.T) { + t.Parallel() + + headers := []string{"user_type", "user_id", "relation", "object_type", "object_id", "condition_context"} + + tests := []struct { + name string + records [][]string + expected string + }{ + { + name: "no records writes only headers", + records: nil, + expected: "user_type,user_id,relation,object_type,object_id,condition_context\n", + }, + { + name: "single record", + records: [][]string{ + {"user", "john", "writer", "document", "abc.doc", ""}, + }, + expected: "user_type,user_id,relation,object_type,object_id,condition_context\n" + + "user,john,writer,document,abc.doc,\n", + }, + { + name: "multiple records", + records: [][]string{ + {"user", "anne", "reader", "document", "x", ""}, + {"group", "eng", "owner", "repo", "y", ""}, + }, + expected: "user_type,user_id,relation,object_type,object_id,condition_context\n" + + "user,anne,reader,document,x,\n" + + "group,eng,owner,repo,y,\n", + }, + { + name: "values with commas, quotes and newlines are escaped", + records: [][]string{ + {"user", "a,b", "say \"hi\"", "doc", "line\nbreak", `{"ip_addr":"10.0.0.1"}`}, + }, + expected: "user_type,user_id,relation,object_type,object_id,condition_context\n" + + "user,\"a,b\",\"say \"\"hi\"\"\",doc,\"line\nbreak\",\"{\"\"ip_addr\"\":\"\"10.0.0.1\"\"}\"\n", + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + t.Parallel() + + got, err := marshalCSV(recordSet(headers, test.records)) + require.NoError(t, err) + assert.Equal(t, test.expected, string(got)) + }) + } +} + +func TestMarshalCSVNotARecordSet(t *testing.T) { + t.Parallel() + + _, err := marshalCSV([]string{"a", "b"}) + assert.ErrorIs(t, err, errNotCSVMarshaler) +} + +func TestMarshalCSVRecordError(t *testing.T) { + t.Parallel() + + sentinel := errors.New("boom") + set := fakeCSVRecordSet{ + headers: []string{"col"}, + records: []CSVMarshaler{fakeCSVRow{err: sentinel}}, + } + + _, err := marshalCSV(set) + assert.ErrorIs(t, err, sentinel) +}