/src/duckdb/src/function/scalar/struct/struct_concat.cpp
Line | Count | Source (jump to first uncovered line) |
1 | | #include "duckdb/function/scalar/nested_functions.hpp" |
2 | | #include "duckdb/function/scalar/struct_functions.hpp" |
3 | | #include "duckdb/common/case_insensitive_map.hpp" |
4 | | #include "duckdb/planner/expression_binder.hpp" |
5 | | #include "duckdb/planner/expression/bound_function_expression.hpp" |
6 | | #include "duckdb/storage/statistics/struct_stats.hpp" |
7 | | |
8 | | namespace duckdb { |
9 | | |
10 | 0 | static void StructConcatFunction(DataChunk &args, ExpressionState &state, Vector &result) { |
11 | 0 | auto &result_cols = StructVector::GetEntries(result); |
12 | 0 | idx_t offset = 0; |
13 | |
|
14 | 0 | if (!args.AllConstant()) { |
15 | | // Unless all arguments are constant, we flatten the input to make sure it's homogeneous |
16 | 0 | args.Flatten(); |
17 | 0 | } |
18 | |
|
19 | 0 | for (auto &arg : args.data) { |
20 | 0 | const auto &child_cols = StructVector::GetEntries(arg); |
21 | 0 | for (auto &child_col : child_cols) { |
22 | 0 | result_cols[offset++]->Reference(*child_col); |
23 | 0 | } |
24 | 0 | } |
25 | 0 | D_ASSERT(offset == result_cols.size()); |
26 | |
|
27 | 0 | if (args.AllConstant()) { |
28 | 0 | result.SetVectorType(VectorType::CONSTANT_VECTOR); |
29 | 0 | } |
30 | |
|
31 | 0 | result.Verify(args.size()); |
32 | 0 | } |
33 | | |
34 | | static unique_ptr<FunctionData> StructConcatBind(ClientContext &context, ScalarFunction &bound_function, |
35 | 0 | vector<unique_ptr<Expression>> &arguments) { |
36 | | |
37 | | // collect names and deconflict, construct return type |
38 | 0 | if (arguments.empty()) { |
39 | 0 | throw InvalidInputException("struct_concat: At least one argument is required"); |
40 | 0 | } |
41 | | |
42 | 0 | child_list_t<LogicalType> combined_children; |
43 | 0 | case_insensitive_set_t name_set; |
44 | |
|
45 | 0 | bool has_unnamed = false; |
46 | |
|
47 | 0 | for (idx_t arg_idx = 0; arg_idx < arguments.size(); arg_idx++) { |
48 | 0 | const auto &arg = arguments[arg_idx]; |
49 | |
|
50 | 0 | if (arg->return_type.id() == LogicalTypeId::UNKNOWN) { |
51 | 0 | throw ParameterNotResolvedException(); |
52 | 0 | } |
53 | | |
54 | 0 | if (arg->return_type.id() != LogicalTypeId::STRUCT) { |
55 | 0 | throw InvalidInputException("struct_concat: Argument at position \"%d\" is not a STRUCT", arg_idx + 1); |
56 | 0 | } |
57 | | |
58 | 0 | const auto &child_types = StructType::GetChildTypes(arg->return_type); |
59 | 0 | for (const auto &child : child_types) { |
60 | 0 | if (!child.first.empty()) { |
61 | 0 | auto it = name_set.find(child.first); |
62 | 0 | if (it != name_set.end()) { |
63 | 0 | if (*it == child.first) { |
64 | 0 | throw InvalidInputException("struct_concat: Arguments contain duplicate STRUCT entry \"%s\"", |
65 | 0 | child.first); |
66 | 0 | } |
67 | 0 | throw InvalidInputException( |
68 | 0 | "struct_concat: Arguments contain case-insensitive duplicate STRUCT entry \"%s\" and \"%s\"", |
69 | 0 | child.first, *it); |
70 | 0 | } |
71 | 0 | name_set.insert(child.first); |
72 | 0 | } else { |
73 | 0 | has_unnamed = true; |
74 | 0 | } |
75 | 0 | combined_children.push_back(child); |
76 | 0 | } |
77 | 0 | } |
78 | | |
79 | 0 | if (has_unnamed && !name_set.empty()) { |
80 | 0 | throw InvalidInputException("struct_concat: Cannot mix named and unnamed STRUCTs"); |
81 | 0 | } |
82 | | |
83 | 0 | bound_function.return_type = LogicalType::STRUCT(combined_children); |
84 | 0 | return nullptr; |
85 | 0 | } |
86 | | |
87 | 0 | static unique_ptr<BaseStatistics> StructConcatStats(ClientContext &context, FunctionStatisticsInput &input) { |
88 | 0 | const auto &expr = input.expr; |
89 | |
|
90 | 0 | auto &arg_stats = input.child_stats; |
91 | 0 | auto &arg_exprs = input.expr.children; |
92 | |
|
93 | 0 | auto struct_stats = StructStats::CreateUnknown(expr.return_type); |
94 | 0 | idx_t struct_index = 0; |
95 | |
|
96 | 0 | for (idx_t arg_idx = 0; arg_idx < arg_exprs.size(); arg_idx++) { |
97 | 0 | auto &arg_stat = arg_stats[arg_idx]; |
98 | 0 | auto &arg_type = arg_exprs[arg_idx]->return_type; |
99 | 0 | for (idx_t child_idx = 0; child_idx < StructType::GetChildCount(arg_type); child_idx++) { |
100 | 0 | auto &child_stat = StructStats::GetChildStats(arg_stat, child_idx); |
101 | 0 | StructStats::SetChildStats(struct_stats, struct_index++, child_stat); |
102 | 0 | } |
103 | 0 | } |
104 | 0 | return struct_stats.ToUnique(); |
105 | 0 | } |
106 | | |
107 | 8.67k | ScalarFunction StructConcatFun::GetFunction() { |
108 | 8.67k | ScalarFunction fun("struct_concat", {}, LogicalTypeId::STRUCT, StructConcatFunction, StructConcatBind, nullptr, |
109 | 8.67k | StructConcatStats); |
110 | 8.67k | fun.varargs = LogicalType::ANY; |
111 | 8.67k | fun.null_handling = FunctionNullHandling::SPECIAL_HANDLING; |
112 | 8.67k | return fun; |
113 | 8.67k | } |
114 | | |
115 | | } // namespace duckdb |