/src/duckdb/src/main/extension/extension_loader.cpp
Line | Count | Source |
1 | | #include "duckdb/main/extension/extension_loader.hpp" |
2 | | |
3 | | #include "duckdb/function/scalar_function.hpp" |
4 | | #include "duckdb/parser/parsed_data/create_aggregate_function_info.hpp" |
5 | | #include "duckdb/parser/parsed_data/create_type_info.hpp" |
6 | | #include "duckdb/parser/parsed_data/create_copy_function_info.hpp" |
7 | | #include "duckdb/parser/parsed_data/create_pragma_function_info.hpp" |
8 | | #include "duckdb/parser/parsed_data/create_scalar_function_info.hpp" |
9 | | #include "duckdb/parser/parsed_data/create_table_function_info.hpp" |
10 | | #include "duckdb/parser/parsed_data/create_macro_info.hpp" |
11 | | #include "duckdb/catalog/catalog_entry/scalar_function_catalog_entry.hpp" |
12 | | #include "duckdb/catalog/catalog_entry/table_function_catalog_entry.hpp" |
13 | | #include "duckdb/parser/parsed_data/create_collation_info.hpp" |
14 | | #include "duckdb/main/extension_install_info.hpp" |
15 | | #include "duckdb/catalog/catalog.hpp" |
16 | | #include "duckdb/main/config.hpp" |
17 | | #include "duckdb/main/secret/secret_manager.hpp" |
18 | | #include "duckdb/main/database.hpp" |
19 | | |
20 | | namespace duckdb { |
21 | | |
22 | | ExtensionLoader::ExtensionLoader(ExtensionActiveLoad &load_info) |
23 | 0 | : db(load_info.db), extension_name(load_info.extension_name), extension_info(load_info.info) { |
24 | 0 | } |
25 | | |
26 | 0 | ExtensionLoader::ExtensionLoader(DatabaseInstance &db, const string &name) : db(db), extension_name(name) { |
27 | 0 | } |
28 | | |
29 | 0 | DatabaseInstance &ExtensionLoader::GetDatabaseInstance() { |
30 | 0 | return db; |
31 | 0 | } |
32 | | |
33 | 0 | void ExtensionLoader::SetDescription(const string &description) { |
34 | 0 | extension_description = description; |
35 | 0 | } |
36 | | |
37 | 0 | void ExtensionLoader::FinalizeLoad() { |
38 | | // Set extension description, if provided |
39 | 0 | if (!extension_description.empty() && extension_info) { |
40 | 0 | auto info = make_uniq<ExtensionLoadedInfo>(); |
41 | 0 | info->description = extension_description; |
42 | 0 | extension_info->load_info = std::move(info); |
43 | 0 | } |
44 | 0 | } |
45 | | |
46 | 0 | void ExtensionLoader::RegisterFunction(ScalarFunction function) { |
47 | 0 | ScalarFunctionSet set(function.name); |
48 | 0 | set.AddFunction(std::move(function)); |
49 | 0 | RegisterFunction(std::move(set)); |
50 | 0 | } |
51 | | |
52 | 0 | void ExtensionLoader::RegisterFunction(ScalarFunctionSet function) { |
53 | 0 | CreateScalarFunctionInfo info(std::move(function)); |
54 | 0 | info.on_conflict = OnCreateConflict::ALTER_ON_CONFLICT; |
55 | 0 | RegisterFunction(std::move(info)); |
56 | 0 | } |
57 | | |
58 | 0 | void ExtensionLoader::RegisterFunction(CreateScalarFunctionInfo function) { |
59 | 0 | D_ASSERT(!function.functions.name.empty()); |
60 | 0 | auto &system_catalog = Catalog::GetSystemCatalog(db); |
61 | 0 | auto data = CatalogTransaction::GetSystemTransaction(db); |
62 | 0 | system_catalog.CreateFunction(data, function); |
63 | 0 | } |
64 | | |
65 | 0 | void ExtensionLoader::RegisterFunction(AggregateFunction function) { |
66 | 0 | AggregateFunctionSet set(function.name); |
67 | 0 | set.AddFunction(std::move(function)); |
68 | 0 | RegisterFunction(std::move(set)); |
69 | 0 | } |
70 | | |
71 | 0 | void ExtensionLoader::RegisterFunction(AggregateFunctionSet function) { |
72 | 0 | CreateAggregateFunctionInfo info(std::move(function)); |
73 | 0 | info.on_conflict = OnCreateConflict::ALTER_ON_CONFLICT; |
74 | 0 | RegisterFunction(std::move(info)); |
75 | 0 | } |
76 | | |
77 | 0 | void ExtensionLoader::RegisterFunction(CreateAggregateFunctionInfo function) { |
78 | 0 | D_ASSERT(!function.functions.name.empty()); |
79 | 0 | auto &system_catalog = Catalog::GetSystemCatalog(db); |
80 | 0 | auto data = CatalogTransaction::GetSystemTransaction(db); |
81 | 0 | system_catalog.CreateFunction(data, function); |
82 | 0 | } |
83 | | |
84 | 0 | void ExtensionLoader::RegisterFunction(CreateSecretFunction function) { |
85 | 0 | D_ASSERT(!function.secret_type.empty()); |
86 | 0 | auto &config = DBConfig::GetConfig(db); |
87 | 0 | config.secret_manager->RegisterSecretFunction(std::move(function), OnCreateConflict::ERROR_ON_CONFLICT); |
88 | 0 | } |
89 | | |
90 | 0 | void ExtensionLoader::RegisterFunction(TableFunction function) { |
91 | 0 | TableFunctionSet set(function.name); |
92 | 0 | set.AddFunction(std::move(function)); |
93 | 0 | RegisterFunction(std::move(set)); |
94 | 0 | } |
95 | | |
96 | 0 | void ExtensionLoader::RegisterFunction(TableFunctionSet function) { |
97 | 0 | D_ASSERT(!function.name.empty()); |
98 | 0 | CreateTableFunctionInfo info(std::move(function)); |
99 | 0 | info.on_conflict = OnCreateConflict::ALTER_ON_CONFLICT; |
100 | 0 | RegisterFunction(std::move(info)); |
101 | 0 | } |
102 | | |
103 | 0 | void ExtensionLoader::RegisterFunction(CreateTableFunctionInfo info) { |
104 | 0 | D_ASSERT(!info.functions.name.empty()); |
105 | 0 | auto &system_catalog = Catalog::GetSystemCatalog(db); |
106 | 0 | auto data = CatalogTransaction::GetSystemTransaction(db); |
107 | 0 | system_catalog.CreateFunction(data, info); |
108 | 0 | } |
109 | | |
110 | 0 | void ExtensionLoader::RegisterFunction(PragmaFunction function) { |
111 | 0 | D_ASSERT(!function.name.empty()); |
112 | 0 | PragmaFunctionSet set(function.name); |
113 | 0 | set.AddFunction(std::move(function)); |
114 | 0 | RegisterFunction(std::move(set)); |
115 | 0 | } |
116 | | |
117 | 0 | void ExtensionLoader::RegisterFunction(PragmaFunctionSet function) { |
118 | 0 | D_ASSERT(!function.name.empty()); |
119 | 0 | auto function_name = function.name; |
120 | 0 | CreatePragmaFunctionInfo info(std::move(function_name), std::move(function)); |
121 | 0 | auto &system_catalog = Catalog::GetSystemCatalog(db); |
122 | 0 | auto data = CatalogTransaction::GetSystemTransaction(db); |
123 | 0 | system_catalog.CreatePragmaFunction(data, info); |
124 | 0 | } |
125 | | |
126 | 0 | void ExtensionLoader::RegisterFunction(CopyFunction function) { |
127 | 0 | CreateCopyFunctionInfo info(std::move(function)); |
128 | 0 | auto &system_catalog = Catalog::GetSystemCatalog(db); |
129 | 0 | auto data = CatalogTransaction::GetSystemTransaction(db); |
130 | 0 | system_catalog.CreateCopyFunction(data, info); |
131 | 0 | } |
132 | | |
133 | 0 | void ExtensionLoader::RegisterFunction(CreateMacroInfo &info) { |
134 | 0 | auto &system_catalog = Catalog::GetSystemCatalog(db); |
135 | 0 | auto data = CatalogTransaction::GetSystemTransaction(db); |
136 | 0 | system_catalog.CreateFunction(data, info); |
137 | 0 | } |
138 | | |
139 | 0 | void ExtensionLoader::RegisterCollation(CreateCollationInfo &info) { |
140 | 0 | auto &system_catalog = Catalog::GetSystemCatalog(db); |
141 | 0 | auto data = CatalogTransaction::GetSystemTransaction(db); |
142 | 0 | info.on_conflict = OnCreateConflict::IGNORE_ON_CONFLICT; |
143 | 0 | system_catalog.CreateCollation(data, info); |
144 | | |
145 | | // Also register as a function for serialisation |
146 | 0 | CreateScalarFunctionInfo finfo(info.function); |
147 | 0 | finfo.on_conflict = OnCreateConflict::IGNORE_ON_CONFLICT; |
148 | 0 | system_catalog.CreateFunction(data, finfo); |
149 | 0 | } |
150 | | |
151 | 0 | void ExtensionLoader::AddFunctionOverload(ScalarFunction function) { |
152 | 0 | auto &scalar_function = GetFunction(function.name); |
153 | 0 | scalar_function.functions.AddFunction(std::move(function)); |
154 | 0 | } |
155 | | |
156 | 0 | void ExtensionLoader::AddFunctionOverload(ScalarFunctionSet functions) { // NOLINT |
157 | 0 | D_ASSERT(!functions.name.empty()); |
158 | 0 | auto &scalar_function = GetFunction(functions.name); |
159 | 0 | for (auto &function : functions.functions) { |
160 | 0 | function.name = functions.name; |
161 | 0 | scalar_function.functions.AddFunction(std::move(function)); |
162 | 0 | } |
163 | 0 | } |
164 | | |
165 | 0 | void ExtensionLoader::AddFunctionOverload(TableFunctionSet functions) { // NOLINT |
166 | 0 | auto &table_function = GetTableFunction(functions.name); |
167 | 0 | for (auto &function : functions.functions) { |
168 | 0 | function.name = functions.name; |
169 | 0 | table_function.functions.AddFunction(std::move(function)); |
170 | 0 | } |
171 | 0 | } |
172 | | |
173 | 0 | static optional_ptr<CatalogEntry> TryGetEntry(DatabaseInstance &db, const string &name, CatalogType type) { |
174 | 0 | D_ASSERT(!name.empty()); |
175 | 0 | auto &system_catalog = Catalog::GetSystemCatalog(db); |
176 | 0 | auto data = CatalogTransaction::GetSystemTransaction(db); |
177 | 0 | auto &schema = system_catalog.GetSchema(data, DEFAULT_SCHEMA); |
178 | 0 | return schema.GetEntry(data, type, name); |
179 | 0 | } |
180 | | |
181 | 0 | optional_ptr<CatalogEntry> ExtensionLoader::TryGetFunction(const string &name) { |
182 | 0 | return TryGetEntry(db, name, CatalogType::SCALAR_FUNCTION_ENTRY); |
183 | 0 | } |
184 | | |
185 | 0 | ScalarFunctionCatalogEntry &ExtensionLoader::GetFunction(const string &name) { |
186 | 0 | auto catalog_entry = TryGetFunction(name); |
187 | 0 | if (!catalog_entry) { |
188 | 0 | throw InvalidInputException("Function with name \"%s\" not found in ExtensionLoader::GetFunction", name); |
189 | 0 | } |
190 | 0 | return catalog_entry->Cast<ScalarFunctionCatalogEntry>(); |
191 | 0 | } |
192 | | |
193 | 0 | optional_ptr<CatalogEntry> ExtensionLoader::TryGetTableFunction(const string &name) { |
194 | 0 | return TryGetEntry(db, name, CatalogType::TABLE_FUNCTION_ENTRY); |
195 | 0 | } |
196 | | |
197 | 0 | TableFunctionCatalogEntry &ExtensionLoader::GetTableFunction(const string &name) { |
198 | 0 | auto catalog_entry = TryGetTableFunction(name); |
199 | 0 | if (!catalog_entry) { |
200 | 0 | throw InvalidInputException("Function with name \"%s\" not found in ExtensionLoader::GetTableFunction", name); |
201 | 0 | } |
202 | 0 | return catalog_entry->Cast<TableFunctionCatalogEntry>(); |
203 | 0 | } |
204 | | |
205 | 0 | void ExtensionLoader::RegisterType(string type_name, LogicalType type, bind_logical_type_function_t bind_modifiers) { |
206 | 0 | D_ASSERT(!type_name.empty()); |
207 | 0 | CreateTypeInfo info(std::move(type_name), std::move(type), bind_modifiers); |
208 | 0 | info.temporary = true; |
209 | 0 | info.internal = true; |
210 | 0 | auto &system_catalog = Catalog::GetSystemCatalog(db); |
211 | 0 | auto data = CatalogTransaction::GetSystemTransaction(db); |
212 | 0 | system_catalog.CreateType(data, info); |
213 | 0 | } |
214 | | |
215 | 0 | void ExtensionLoader::RegisterSecretType(SecretType secret_type) { |
216 | 0 | auto &config = DBConfig::GetConfig(db); |
217 | 0 | config.secret_manager->RegisterSecretType(secret_type); |
218 | 0 | } |
219 | | |
220 | | void ExtensionLoader::RegisterCastFunction(const LogicalType &source, const LogicalType &target, |
221 | 0 | bind_cast_function_t bind_function, int64_t implicit_cast_cost) { |
222 | 0 | auto &config = DBConfig::GetConfig(db); |
223 | 0 | auto &casts = config.GetCastFunctions(); |
224 | 0 | casts.RegisterCastFunction(source, target, bind_function, implicit_cast_cost); |
225 | 0 | } |
226 | | |
227 | | void ExtensionLoader::RegisterCastFunction(const LogicalType &source, const LogicalType &target, BoundCastInfo function, |
228 | 0 | int64_t implicit_cast_cost) { |
229 | 0 | auto &config = DBConfig::GetConfig(db); |
230 | 0 | auto &casts = config.GetCastFunctions(); |
231 | 0 | casts.RegisterCastFunction(source, target, std::move(function), implicit_cast_cost); |
232 | 0 | } |
233 | | |
234 | | } // namespace duckdb |