forked from ClickHouse/ClickHouse
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy patharrayFlatten.cpp
130 lines (98 loc) · 4.15 KB
/
arrayFlatten.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
#include <Functions/IFunction.h>
#include <Functions/FunctionFactory.h>
#include <Functions/FunctionHelpers.h>
#include <DataTypes/DataTypeArray.h>
#include <Columns/ColumnArray.h>
namespace DB
{
/// Error codes referenced in this file. They are only declared here (extern); the definitions live elsewhere.
namespace ErrorCodes
{
/// Used by ArrayFlatten::getReturnTypeImpl when the argument type is not an array.
extern const int ILLEGAL_TYPE_OF_ARGUMENT;
/// Used by ArrayFlatten::executeImpl when the argument column is not a ColumnArray.
extern const int ILLEGAL_COLUMN;
}
/** arrayFlatten(arr) - collapses an array nested to any depth into a one-level array.
  * Example: arrayFlatten([[1, 2, 3], [4, 5]]) = [1, 2, 3, 4, 5]
  */
class ArrayFlatten : public IFunction
{
public:
    static constexpr auto name = "arrayFlatten";

    static FunctionPtr create(ContextPtr)
    {
        return std::make_shared<ArrayFlatten>();
    }

    /// The function takes exactly one argument.
    size_t getNumberOfArguments() const override { return 1; }

    bool useDefaultImplementationForConstants() const override { return true; }

    bool isSuitableForShortCircuitArgumentsExecution(const DataTypesWithConstInfo & /*arguments*/) const override { return true; }

    /// Strips every level of array nesting from the argument type and returns
    /// Array(<innermost element type>). Throws ILLEGAL_TYPE_OF_ARGUMENT if the argument is not an array.
    DataTypePtr getReturnTypeImpl(const DataTypes & arguments) const override
    {
        const auto & argument_type = arguments[0];

        if (!isArray(argument_type))
            throw Exception("Illegal type " + argument_type->getName() +
                " of argument of function " + getName() +
                ", expected Array", ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);

        /// Descend through the array types until a non-array element type is reached.
        DataTypePtr innermost_type = argument_type;
        while (isArray(innermost_type))
        {
            const auto * array_type = checkAndGetDataType<DataTypeArray>(innermost_type.get());
            innermost_type = array_type->getNestedType();
        }

        return std::make_shared<DataTypeArray>(innermost_type);
    }

    /// Builds the flattened column. Throws ILLEGAL_COLUMN if the argument column is not a ColumnArray.
    ColumnPtr executeImpl(const ColumnsWithTypeAndName & arguments, const DataTypePtr &, size_t input_rows_count) const override
    {
        /** The result reuses the innermost (non-array) data column as is; only the offsets are rebuilt.
          * Going one nesting level deeper at a time, the end offset of every row is replaced by the
          * value stored in the deeper offsets array at position (current end offset - 1).
          *
          * Example 1. Source column of type Array(Array(UInt8)):
          *     Row 1: [[1, 2, 3], [4, 5]], Row 2: [[6], [7, 8]]
          *     outer offsets:  2 4
          *     inner offsets:  3 5 6 8
          *     inner data:     1 2 3 4 5 6 7 8
          *   Picking the inner offsets at positions 2 - 1 and 4 - 1 gives the result offsets 5 8, so
          *     Row 1: [1, 2, 3, 4, 5], Row 2: [6, 7, 8]
          *
          * Example 2. Source column of type Array(Array(Array(UInt8))):
          *     Row 1: [[], [[1], [], [2, 3]]], Row 2: [[[4]]]
          *     offsets1 (outermost): 2 3
          *     offsets2:             0 3 4
          *     offsets3 (innermost): 1 1 3 4
          *     innermost data:       1 2 3 4
          *   offsets1 selects 3 4 from offsets2; these in turn select 3 4 from offsets3.
          *   Result offsets: 3 4, i.e. Row 1: [1, 2, 3], Row 2: [4]
          */
        const ColumnArray * outer_array = checkAndGetColumn<ColumnArray>(arguments[0].column.get());
        if (!outer_array)
            throw Exception("Illegal column " + arguments[0].column->getName() + " in argument of function 'arrayFlatten'",
                ErrorCodes::ILLEGAL_COLUMN);

        /// Stays empty when the source has a single array level; then the source offsets are reused below.
        ColumnArray::ColumnOffsets::MutablePtr flattened_offsets_column;

        const IColumn::Offsets * parent_offsets = &outer_array->getOffsets();
        const IColumn * current_data = &outer_array->getData();

        /// Peel one array level per iteration for as long as the data column is itself an array column.
        while (const ColumnArray * inner_array = checkAndGetColumn<ColumnArray>(current_data))
        {
            if (!flattened_offsets_column)
                flattened_offsets_column = ColumnArray::ColumnOffsets::create(input_rows_count);

            IColumn::Offsets & flattened_offsets = flattened_offsets_column->getData();
            const IColumn::Offsets & child_offsets = inner_array->getOffsets();

            /// From the second level on, parent_offsets points at flattened_offsets itself;
            /// that is fine because every row reads its own slot before overwriting it.
            for (size_t row = 0; row < input_rows_count; ++row)
            {
                const auto parent_end = (*parent_offsets)[row];
                flattened_offsets[row] = child_offsets[parent_end - 1]; /// -1 array subscript is Ok, see PaddedPODArray
            }

            parent_offsets = &flattened_offsets;
            current_data = &inner_array->getData();
        }

        ColumnPtr final_offsets = flattened_offsets_column ? std::move(flattened_offsets_column) : outer_array->getOffsetsPtr();
        return ColumnArray::create(current_data->getPtr(), final_offsets);
    }

private:
    /// Name of the function, same as the static `name` constant.
    String getName() const override
    {
        return name;
    }
};
/// Registers ArrayFlatten in the function factory and adds the alias 'flatten' for it.
REGISTER_FUNCTION(ArrayFlatten)
{
    factory.registerFunction<ArrayFlatten>();

    /// The alias is registered with the CaseInsensitive flag and points at the function's own name.
    factory.registerAlias("flatten", ArrayFlatten::name, FunctionFactory::CaseInsensitive);
}
}