Skip to content

Commit 16cc778

Browse files
author
tengu-alt
committed
scan to any were implemented for all simple types
1 parent 974fa12 commit 16cc778

File tree

2 files changed

+263
-1
lines changed

2 files changed

+263
-1
lines changed

cassandra_test.go

Lines changed: 161 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ import (
4444
"time"
4545
"unicode"
4646

47-
inf "gopkg.in/inf.v0"
47+
"gopkg.in/inf.v0"
4848
)
4949

5050
func TestEmptyHosts(t *testing.T) {
@@ -3288,3 +3288,163 @@ func TestQuery_NamedValues(t *testing.T) {
32883288
t.Fatal(err)
32893289
}
32903290
}
3291+
3292+
func TestScanToAny(t *testing.T) {
3293+
session := createSession(t)
3294+
defer session.Close()
3295+
ctx := context.Background()
3296+
3297+
dataTypes := []struct {
3298+
tableName string
3299+
createQuery string
3300+
insertQuery string
3301+
expectedVal interface{}
3302+
}{
3303+
{
3304+
"scan_to_any_varchar",
3305+
"CREATE TABLE IF NOT EXISTS scan_to_any_varchar (id int PRIMARY KEY, val varchar)",
3306+
"INSERT INTO scan_to_any_varchar (id, val) VALUES (?, ?)",
3307+
"test",
3308+
},
3309+
{
3310+
"scan_to_any_bool",
3311+
"CREATE TABLE IF NOT EXISTS scan_to_any_bool (id int PRIMARY KEY, val boolean)",
3312+
"INSERT INTO scan_to_any_bool (id, val) VALUES (?, ?)",
3313+
true,
3314+
},
3315+
{
3316+
"scan_to_any_int",
3317+
"CREATE TABLE IF NOT EXISTS scan_to_any_int (id int PRIMARY KEY, val int)",
3318+
"INSERT INTO scan_to_any_int (id, val) VALUES (?, ?)",
3319+
42,
3320+
},
3321+
{
3322+
"scan_to_any_float",
3323+
"CREATE TABLE IF NOT EXISTS scan_to_any_float (id int PRIMARY KEY, val float)",
3324+
"INSERT INTO scan_to_any_float (id, val) VALUES (?, ?)",
3325+
float32(3.14),
3326+
},
3327+
{
3328+
"scan_to_any_double",
3329+
"CREATE TABLE IF NOT EXISTS scan_to_any_double (id int PRIMARY KEY, val double)",
3330+
"INSERT INTO scan_to_any_double (id, val) VALUES (?, ?)",
3331+
3.14159,
3332+
},
3333+
{
3334+
"scan_to_any_decimal",
3335+
"CREATE TABLE IF NOT EXISTS scan_to_any_decimal (id int PRIMARY KEY, val decimal)",
3336+
"INSERT INTO scan_to_any_decimal (id, val) VALUES (?, ?)",
3337+
inf.NewDec(12345, 2), // Example decimal value
3338+
},
3339+
{
3340+
"scan_to_any_time",
3341+
"CREATE TABLE IF NOT EXISTS scan_to_any_time (id int PRIMARY KEY, val time)",
3342+
"INSERT INTO scan_to_any_time (id, val) VALUES (?, ?)",
3343+
time.Duration(1000),
3344+
},
3345+
{
3346+
"scan_to_any_timestamp",
3347+
"CREATE TABLE IF NOT EXISTS scan_to_any_timestamp (id int PRIMARY KEY, val timestamp)",
3348+
"INSERT INTO scan_to_any_timestamp (id, val) VALUES (?, ?)",
3349+
time.Now().UTC().Truncate(time.Millisecond),
3350+
},
3351+
{
3352+
"scan_to_any_inet",
3353+
"CREATE TABLE IF NOT EXISTS scan_to_any_inet (id int PRIMARY KEY, val inet)",
3354+
"INSERT INTO scan_to_any_inet (id, val) VALUES (?, ?)",
3355+
net.ParseIP("192.168.0.1"),
3356+
},
3357+
{
3358+
"scan_to_any_uuid",
3359+
"CREATE TABLE IF NOT EXISTS scan_to_any_uuid (id int PRIMARY KEY, val uuid)",
3360+
"INSERT INTO scan_to_any_uuid (id, val) VALUES (?, ?)",
3361+
TimeUUID().String(),
3362+
},
3363+
{
3364+
"scan_to_any_date",
3365+
"CREATE TABLE IF NOT EXISTS scan_to_any_date (id int PRIMARY KEY, val date)",
3366+
"INSERT INTO scan_to_any_date (id, val) VALUES (?, ?)",
3367+
time.Now().UTC().Truncate(time.Hour * 24),
3368+
},
3369+
{
3370+
"scan_to_any_duration",
3371+
"CREATE TABLE IF NOT EXISTS scan_to_any_duration (id int PRIMARY KEY, val duration)",
3372+
"INSERT INTO scan_to_any_duration (id, val) VALUES (?, ?)",
3373+
Duration{0, 0, 123},
3374+
},
3375+
}
3376+
3377+
for _, dt := range dataTypes {
3378+
t.Run(fmt.Sprintf("Test_%s", dt.tableName), func(t *testing.T) {
3379+
if err := session.Query(dt.createQuery).WithContext(ctx).Exec(); err != nil {
3380+
t.Fatal(err)
3381+
}
3382+
3383+
if err := session.Query(dt.insertQuery, 1, dt.expectedVal).WithContext(ctx).Exec(); err != nil {
3384+
t.Fatal(err)
3385+
}
3386+
3387+
var out interface{}
3388+
if err := session.Query(fmt.Sprintf("SELECT val FROM %s WHERE id = 1", dt.tableName)).WithContext(ctx).Scan(&out); err != nil {
3389+
t.Fatal(err)
3390+
}
3391+
3392+
if err := session.Query(fmt.Sprintf("DROP TABLE %s", dt.tableName)).WithContext(ctx).Exec(); err != nil {
3393+
t.Fatal(err)
3394+
}
3395+
3396+
switch dt.tableName {
3397+
case "scan_to_any_decimal":
3398+
result, ok := out.(inf.Dec)
3399+
if !ok {
3400+
t.Fatal("expected inf.Dec, got", out)
3401+
}
3402+
expected := inf.NewDec(12345, 2)
3403+
3404+
if result.Cmp(expected) != 0 {
3405+
t.Fatalf("expected %v, got %v", expected, out)
3406+
}
3407+
case "scan_to_any_inet":
3408+
result, ok := out.(net.IP)
3409+
if !ok {
3410+
t.Fatal("expected net.IP, got", out)
3411+
}
3412+
expected, ok := dt.expectedVal.(net.IP)
3413+
if !ok {
3414+
t.Fatal("expected net.IP, got", dt.expectedVal)
3415+
}
3416+
if result.String() != expected.String() {
3417+
t.Fatalf("expected %v, got %v", expected, out)
3418+
}
3419+
case "scan_to_any_date":
3420+
result, ok := out.(time.Time)
3421+
if !ok {
3422+
t.Fatal("expected time.Time, got", out)
3423+
}
3424+
expected, ok := dt.expectedVal.(time.Time)
3425+
if !ok {
3426+
t.Fatal("expected time.Time, got", dt.expectedVal)
3427+
}
3428+
if result.String() != expected.String() {
3429+
t.Fatalf("expected %v, got %v", expected, out)
3430+
}
3431+
case "scan_to_any_duration":
3432+
result, ok := out.(Duration)
3433+
if !ok {
3434+
t.Fatal("expected time.Duration, got", out)
3435+
}
3436+
expected, ok := dt.expectedVal.(Duration)
3437+
if !ok {
3438+
t.Fatal("expected time.Duration, got", dt.expectedVal)
3439+
}
3440+
if result != expected {
3441+
t.Fatalf("expected %v, got %v", expected, out)
3442+
}
3443+
default:
3444+
if out != dt.expectedVal {
3445+
t.Fatalf("expected %v, got %v", dt.expectedVal, out)
3446+
}
3447+
}
3448+
})
3449+
}
3450+
}

marshal.go

Lines changed: 102 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -222,6 +222,23 @@ func Marshal(info TypeInfo, value interface{}) ([]byte, error) {
222222
// date | *time.Time | time of beginning of the day (in UTC)
223223
// date | *string | formatted with 2006-01-02 format
224224
// duration | *gocql.Duration |
225+
//
226+
// Scan into interface{} implemented by unmarshal into default type:
227+
//
228+
// CQL type | Go type
229+
// Varchar | string
230+
// Varint | bigInt
231+
// IntLike | int
232+
// Boolean | bool
233+
// Float | float32
234+
// Double | float64
235+
// Decimal | infDec
236+
// Time | time.Duration
237+
// Timestamp | time.Time
238+
// Date | time.Time
239+
// Duration | Duration
240+
// UUID | string
241+
// Inet | net.IP
225242
func Unmarshal(info TypeInfo, data []byte, value interface{}) error {
226243
if v, ok := value.(Unmarshaler); ok {
227244
return v.UnmarshalCQL(info, data)
@@ -350,6 +367,9 @@ func unmarshalVarchar(info TypeInfo, data []byte, value interface{}) error {
350367
*v = nil
351368
}
352369
return nil
370+
case *interface{}:
371+
*v = string(data)
372+
return nil
353373
}
354374

355375
rv := reflect.ValueOf(value)
@@ -743,6 +763,8 @@ func unmarshalVarint(info TypeInfo, data []byte, value interface{}) error {
743763
*v = bytesToUint64(data[1:])
744764
return nil
745765
}
766+
case *interface{}:
767+
return unmarshalBigInt(info, data, value)
746768
}
747769

748770
if len(data) > 8 {
@@ -904,6 +926,12 @@ func unmarshalIntlike(info TypeInfo, int64Val int64, data []byte, value interfac
904926
case *string:
905927
*v = strconv.FormatInt(int64Val, 10)
906928
return nil
929+
case *interface{}:
930+
if ^uint(0) == math.MaxUint32 && (int64Val < math.MinInt32 || int64Val > math.MaxInt32) {
931+
return unmarshalErrorf("unmarshal int: value %d out of range for %s", int64Val, info.Type())
932+
}
933+
*v = int(int64Val)
934+
return nil
907935
}
908936

909937
rv := reflect.ValueOf(value)
@@ -1055,6 +1083,9 @@ func unmarshalBool(info TypeInfo, data []byte, value interface{}) error {
10551083
case *bool:
10561084
*v = decBool(data)
10571085
return nil
1086+
case *interface{}:
1087+
*v = decBool(data)
1088+
return nil
10581089
}
10591090
rv := reflect.ValueOf(value)
10601091
if rv.Kind() != reflect.Ptr {
@@ -1105,6 +1136,9 @@ func unmarshalFloat(info TypeInfo, data []byte, value interface{}) error {
11051136
case *float32:
11061137
*v = math.Float32frombits(uint32(decInt(data)))
11071138
return nil
1139+
case *interface{}:
1140+
*v = math.Float32frombits(uint32(decInt(data)))
1141+
return nil
11081142
}
11091143
rv := reflect.ValueOf(value)
11101144
if rv.Kind() != reflect.Ptr {
@@ -1146,6 +1180,9 @@ func unmarshalDouble(info TypeInfo, data []byte, value interface{}) error {
11461180
case *float64:
11471181
*v = math.Float64frombits(uint64(decBigInt(data)))
11481182
return nil
1183+
case *interface{}:
1184+
*v = math.Float64frombits(uint64(decBigInt(data)))
1185+
return nil
11491186
}
11501187
rv := reflect.ValueOf(value)
11511188
if rv.Kind() != reflect.Ptr {
@@ -1196,6 +1233,14 @@ func unmarshalDecimal(info TypeInfo, data []byte, value interface{}) error {
11961233
unscaled := decBigInt2C(data[4:], nil)
11971234
*v = *inf.NewDecBig(unscaled, inf.Scale(scale))
11981235
return nil
1236+
case *interface{}:
1237+
if len(data) < 4 {
1238+
return unmarshalErrorf("inf.Dec needs at least 4 bytes, while value has only %d", len(data))
1239+
}
1240+
scale := decInt(data[0:4])
1241+
unscaled := decBigInt2C(data[4:], nil)
1242+
*v = *inf.NewDecBig(unscaled, inf.Scale(scale))
1243+
return nil
11991244
}
12001245
return unmarshalErrorf("can not unmarshal %s into %T", info, value)
12011246
}
@@ -1302,6 +1347,9 @@ func unmarshalTime(info TypeInfo, data []byte, value interface{}) error {
13021347
case *time.Duration:
13031348
*v = time.Duration(decBigInt(data))
13041349
return nil
1350+
case *interface{}:
1351+
*v = time.Duration(decBigInt(data))
1352+
return nil
13051353
}
13061354

13071355
rv := reflect.ValueOf(value)
@@ -1334,6 +1382,16 @@ func unmarshalTimestamp(info TypeInfo, data []byte, value interface{}) error {
13341382
nsec := (x - sec*1000) * 1000000
13351383
*v = time.Unix(sec, nsec).In(time.UTC)
13361384
return nil
1385+
case *interface{}:
1386+
if len(data) == 0 {
1387+
*v = time.Time{}
1388+
return nil
1389+
}
1390+
x := decBigInt(data)
1391+
sec := x / 1000
1392+
nsec := (x - sec*1000) * 1000000
1393+
*v = time.Unix(sec, nsec).In(time.UTC)
1394+
return nil
13371395
}
13381396

13391397
rv := reflect.ValueOf(value)
@@ -1419,6 +1477,16 @@ func unmarshalDate(info TypeInfo, data []byte, value interface{}) error {
14191477
timestamp := (int64(current) - int64(origin)) * millisecondsInADay
14201478
*v = time.UnixMilli(timestamp).In(time.UTC).Format("2006-01-02")
14211479
return nil
1480+
case *interface{}:
1481+
if len(data) == 0 {
1482+
*v = time.Time{}
1483+
return nil
1484+
}
1485+
var origin uint32 = 1 << 31
1486+
var current uint32 = binary.BigEndian.Uint32(data)
1487+
timestamp := (int64(current) - int64(origin)) * millisecondsInADay
1488+
*v = time.UnixMilli(timestamp).In(time.UTC)
1489+
return nil
14221490
}
14231491
return unmarshalErrorf("can not unmarshal %s into %T", info, value)
14241492
}
@@ -1478,6 +1546,25 @@ func unmarshalDuration(info TypeInfo, data []byte, value interface{}) error {
14781546
Nanoseconds: nanos,
14791547
}
14801548
return nil
1549+
case *interface{}:
1550+
if len(data) == 0 {
1551+
*v = Duration{
1552+
Months: 0,
1553+
Days: 0,
1554+
Nanoseconds: 0,
1555+
}
1556+
return nil
1557+
}
1558+
months, days, nanos, err := decVints(data)
1559+
if err != nil {
1560+
return unmarshalErrorf("failed to unmarshal %s into %T: %s", info, value, err.Error())
1561+
}
1562+
*v = Duration{
1563+
Months: months,
1564+
Days: days,
1565+
Nanoseconds: nanos,
1566+
}
1567+
return nil
14811568
}
14821569
return unmarshalErrorf("can not unmarshal %s into %T", info, value)
14831570
}
@@ -1914,6 +2001,9 @@ func unmarshalUUID(info TypeInfo, data []byte, value interface{}) error {
19142001
case *[]byte:
19152002
*v = u[:]
19162003
return nil
2004+
case *interface{}:
2005+
*v = u.String()
2006+
return nil
19172007
}
19182008
return unmarshalErrorf("can not unmarshal X %s into %T", info, value)
19192009
}
@@ -1996,6 +2086,18 @@ func unmarshalInet(info TypeInfo, data []byte, value interface{}) error {
19962086
}
19972087
*v = ip.String()
19982088
return nil
2089+
case *interface{}:
2090+
if x := len(data); !(x == 4 || x == 16) {
2091+
return unmarshalErrorf("cannot unmarshal %s into %T: invalid sized IP: got %d bytes not 4 or 16", info, value, x)
2092+
}
2093+
buf := copyBytes(data)
2094+
ip := net.IP(buf)
2095+
if v4 := ip.To4(); v4 != nil {
2096+
*v = v4
2097+
return nil
2098+
}
2099+
*v = ip
2100+
return nil
19992101
}
20002102
return unmarshalErrorf("cannot unmarshal %s into %T", info, value)
20012103
}

0 commit comments

Comments
 (0)