iceberg-cpp
Loading...
Searching...
No Matches
projection_util_internal.h
1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20#pragma once
21
22#include <algorithm>
23#include <memory>
24#include <ranges>
25#include <string>
26#include <string_view>
27#include <unordered_set>
28#include <utility>
29
30#include "iceberg/expression/literal.h"
33#include "iceberg/result.h"
34#include "iceberg/transform.h"
38#include "iceberg/util/macros.h"
39#include "iceberg/util/string_util.h"
40
41namespace iceberg {
42
44 private:
45 static Result<Literal> AdjustLiteral(const Literal& literal, int adjustment) {
46 switch (literal.type()->type_id()) {
47 case TypeId::kInt:
48 return Literal::Int(std::get<int32_t>(literal.value()) + adjustment);
49 case TypeId::kLong:
50 return Literal::Long(std::get<int64_t>(literal.value()) + adjustment);
51 case TypeId::kDate:
52 return Literal::Date(std::get<int32_t>(literal.value()) + adjustment);
53 case TypeId::kTimestamp:
54 return Literal::Timestamp(std::get<int64_t>(literal.value()) + adjustment);
55 case TypeId::kTimestampTz:
56 return Literal::TimestampTz(std::get<int64_t>(literal.value()) + adjustment);
57 case TypeId::kDecimal: {
58 const auto& decimal_type =
59 internal::checked_cast<const DecimalType&>(*literal.type());
60 Decimal adjusted = std::get<Decimal>(literal.value()) + Decimal(adjustment);
61 return Literal::Decimal(adjusted.value(), decimal_type.precision(),
62 decimal_type.scale());
63 }
64 default:
65 return NotSupported("{} is not a valid literal type for value adjustment",
66 literal.type()->ToString());
67 }
68 }
69
70 static Result<Literal> PlusOne(const Literal& literal) {
71 return AdjustLiteral(literal, /*adjustment=*/+1);
72 }
73
74 static Result<Literal> MinusOne(const Literal& literal) {
75 return AdjustLiteral(literal, /*adjustment=*/-1);
76 }
77
78 static Result<std::unique_ptr<UnboundPredicate>> MakePredicate(
79 Expression::Operation op, std::string_view name,
80 const std::shared_ptr<TransformFunction>& func, const Literal& literal) {
81 ICEBERG_ASSIGN_OR_RAISE(auto ref, NamedReference::Make(std::string(name)));
82 ICEBERG_ASSIGN_OR_RAISE(auto lit, func->Transform(literal));
83 return UnboundPredicateImpl<BoundReference>::Make(op, std::move(ref), std::move(lit));
84 }
85
86 static Result<std::unique_ptr<UnboundPredicate>> TransformSet(
87 std::string_view name, const std::shared_ptr<BoundSetPredicate>& pred,
88 const std::shared_ptr<TransformFunction>& func) {
89 std::vector<Literal> transformed;
90 transformed.reserve(pred->literal_set().size());
91 for (const auto& lit : pred->literal_set()) {
92 ICEBERG_ASSIGN_OR_RAISE(auto transformed_lit, func->Transform(lit));
93 transformed.push_back(std::move(transformed_lit));
94 }
95 ICEBERG_ASSIGN_OR_RAISE(auto ref, NamedReference::Make(std::string(name)));
96 return UnboundPredicateImpl<BoundReference>::Make(pred->op(), std::move(ref),
97 std::move(transformed));
98 }
99
100 static Result<std::unique_ptr<UnboundPredicate>> TruncateByteArray(
101 std::string_view name, const std::shared_ptr<BoundLiteralPredicate>& pred,
102 const std::shared_ptr<TransformFunction>& func) {
103 switch (pred->op()) {
104 case Expression::Operation::kLt:
105 case Expression::Operation::kLtEq:
106 return MakePredicate(Expression::Operation::kLtEq, name, func, pred->literal());
107 case Expression::Operation::kGt:
108 case Expression::Operation::kGtEq:
109 return MakePredicate(Expression::Operation::kGtEq, name, func, pred->literal());
110 case Expression::Operation::kEq:
111 case Expression::Operation::kStartsWith:
112 return MakePredicate(pred->op(), name, func, pred->literal());
113 default:
114 return nullptr;
115 }
116 }
117
118 static Result<std::unique_ptr<UnboundPredicate>> TruncateByteArrayStrict(
119 std::string_view name, const std::shared_ptr<BoundLiteralPredicate>& pred,
120 const std::shared_ptr<TransformFunction>& func) {
121 switch (pred->op()) {
122 case Expression::Operation::kLt:
123 case Expression::Operation::kLtEq:
124 return MakePredicate(Expression::Operation::kLt, name, func, pred->literal());
125 case Expression::Operation::kGt:
126 case Expression::Operation::kGtEq:
127 return MakePredicate(Expression::Operation::kGt, name, func, pred->literal());
128 case Expression::Operation::kNotEq:
129 return MakePredicate(Expression::Operation::kNotEq, name, func, pred->literal());
130 default:
131 return nullptr;
132 }
133 }
134
135 // Apply to int32, int64, decimal, and temporal types
136 static Result<std::unique_ptr<UnboundPredicate>> TransformNumeric(
137 std::string_view name, const std::shared_ptr<BoundLiteralPredicate>& pred,
138 const std::shared_ptr<TransformFunction>& func) {
139 switch (func->source_type()->type_id()) {
140 case TypeId::kInt:
141 case TypeId::kLong:
142 case TypeId::kDecimal:
143 case TypeId::kDate:
144 case TypeId::kTimestamp:
145 case TypeId::kTimestampTz:
146 break;
147 default:
148 return NotSupported("{} is not a valid input type for numeric transform",
149 func->source_type()->ToString());
150 }
151
152 switch (pred->op()) {
153 case Expression::Operation::kLt: {
154 // adjust closed and then transform ltEq
155 ICEBERG_ASSIGN_OR_RAISE(auto adjusted, MinusOne(pred->literal()));
156 return MakePredicate(Expression::Operation::kLtEq, name, func, adjusted);
157 }
158 case Expression::Operation::kGt: {
159 // adjust closed and then transform gtEq
160 ICEBERG_ASSIGN_OR_RAISE(auto adjusted, PlusOne(pred->literal()));
161 return MakePredicate(Expression::Operation::kGtEq, name, func, adjusted);
162 }
163 case Expression::Operation::kLtEq:
164 case Expression::Operation::kGtEq:
165 case Expression::Operation::kEq:
166 return MakePredicate(pred->op(), name, func, pred->literal());
167 default:
168 return nullptr;
169 }
170 }
171
172 static Result<std::unique_ptr<UnboundPredicate>> TransformNumericStrict(
173 std::string_view name, const std::shared_ptr<BoundLiteralPredicate>& pred,
174 const std::shared_ptr<TransformFunction>& func) {
175 switch (func->source_type()->type_id()) {
176 case TypeId::kInt:
177 case TypeId::kLong:
178 case TypeId::kDecimal:
179 case TypeId::kDate:
180 case TypeId::kTimestamp:
181 case TypeId::kTimestampTz:
182 break;
183 default:
184 return NotSupported("{} is not a valid input type for numeric transform",
185 func->source_type()->ToString());
186 }
187
188 switch (pred->op()) {
189 case Expression::Operation::kLtEq: {
190 ICEBERG_ASSIGN_OR_RAISE(auto adjusted, PlusOne(pred->literal()));
191 return MakePredicate(Expression::Operation::kLt, name, func, adjusted);
192 }
193 case Expression::Operation::kGtEq: {
194 ICEBERG_ASSIGN_OR_RAISE(auto adjusted, MinusOne(pred->literal()));
195 return MakePredicate(Expression::Operation::kGt, name, func, adjusted);
196 }
197 case Expression::Operation::kLt:
198 case Expression::Operation::kGt:
199 case Expression::Operation::kNotEq:
200 return MakePredicate(pred->op(), name, func, pred->literal());
201 default:
202 return nullptr;
203 }
204 }
205
206 static Result<std::unique_ptr<UnboundPredicate>> TruncateStringLiteral(
207 std::string_view name, const std::shared_ptr<BoundLiteralPredicate>& pred,
208 const std::shared_ptr<TransformFunction>& func) {
209 const auto op = pred->op();
210 if (op != Expression::Operation::kStartsWith &&
211 op != Expression::Operation::kNotStartsWith) {
212 return TruncateByteArray(name, pred, func);
213 }
214
215 const auto& literal = pred->literal();
216 const auto length =
217 StringUtils::CodePointCount(std::get<std::string>(literal.value()));
218 const auto width = static_cast<size_t>(
219 internal::checked_pointer_cast<TruncateTransform>(func)->width());
220
221 if (length < width) {
222 return MakePredicate(op, name, func, literal);
223 }
224
225 if (length == width) {
226 if (op == Expression::Operation::kStartsWith) {
227 return MakePredicate(Expression::Operation::kEq, name, func, literal);
228 } else {
229 return MakePredicate(Expression::Operation::kNotEq, name, func, literal);
230 }
231 }
232
233 if (op == Expression::Operation::kStartsWith) {
234 return TruncateByteArray(name, pred, func);
235 }
236
237 return nullptr;
238 }
239
240 static Result<std::unique_ptr<UnboundPredicate>> TruncateStringLiteralStrict(
241 std::string_view name, const std::shared_ptr<BoundLiteralPredicate>& pred,
242 const std::shared_ptr<TransformFunction>& func) {
243 const auto op = pred->op();
244 if (op != Expression::Operation::kStartsWith &&
245 op != Expression::Operation::kNotStartsWith) {
246 return TruncateByteArrayStrict(name, pred, func);
247 }
248
249 const auto& literal = pred->literal();
250 const auto length =
251 StringUtils::CodePointCount(std::get<std::string>(literal.value()));
252 const auto width = static_cast<size_t>(
253 internal::checked_pointer_cast<TruncateTransform>(func)->width());
254
255 if (length < width) {
256 return MakePredicate(op, name, func, literal);
257 }
258
259 if (length == width) {
260 if (op == Expression::Operation::kStartsWith) {
261 return MakePredicate(Expression::Operation::kEq, name, func, literal);
262 } else {
263 return MakePredicate(Expression::Operation::kNotEq, name, func, literal);
264 }
265 }
266
267 if (op == Expression::Operation::kNotStartsWith) {
268 return MakePredicate(Expression::Operation::kNotStartsWith, name, func, literal);
269 }
270
271 return nullptr;
272 }
273
274 // Fixes an inclusive projection to account for incorrectly transformed values.
275 // align with Java implementation:
276 // https://github.com/apache/iceberg/blob/1.10.x/api/src/main/java/org/apache/iceberg/transforms/ProjectionUtil.java#L275
277 static Result<std::unique_ptr<UnboundPredicate>> FixInclusiveTimeProjection(
278 std::unique_ptr<UnboundPredicateImpl<BoundReference>> projected) {
279 if (projected == nullptr) {
280 return nullptr;
281 }
282
283 // adjust the predicate for values that were 1 larger than the correct transformed
284 // value
285 switch (projected->op()) {
286 case Expression::Operation::kLt: {
287 ICEBERG_DCHECK(!projected->literals().empty(), "Expected at least one literal");
288 const auto& literal = projected->literals().front();
289 ICEBERG_DCHECK(std::holds_alternative<int32_t>(literal.value()),
290 "Expected int32_t");
291 if (auto value = std::get<int32_t>(literal.value()); value < 0) {
292 return UnboundPredicateImpl<BoundReference>::Make(Expression::Operation::kLt,
293 std::move(projected->term()),
294 Literal::Int(value + 1));
295 }
296
297 return projected;
298 }
299
300 case Expression::Operation::kLtEq: {
301 ICEBERG_DCHECK(!projected->literals().empty(), "Expected at least one literal");
302 const auto& literal = projected->literals().front();
303 ICEBERG_DCHECK(std::holds_alternative<int32_t>(literal.value()),
304 "Expected int32_t");
305
306 if (auto value = std::get<int32_t>(literal.value()); value < 0) {
307 return UnboundPredicateImpl<BoundReference>::Make(Expression::Operation::kLtEq,
308 std::move(projected->term()),
309 Literal::Int(value + 1));
310 }
311 return projected;
312 }
313
314 case Expression::Operation::kGt:
315 case Expression::Operation::kGtEq:
316 // incorrect projected values are already greater than the bound for GT, GT_EQ
317 return projected;
318
319 case Expression::Operation::kEq: {
320 ICEBERG_DCHECK(!projected->literals().empty(), "Expected at least one literal");
321 const auto& literal = projected->literals().front();
322 ICEBERG_DCHECK(std::holds_alternative<int32_t>(literal.value()),
323 "Expected int32_t");
324 if (auto value = std::get<int32_t>(literal.value()); value < 0) {
325 // match either the incorrect value (projectedValue + 1) or the correct value
326 // (projectedValue)
328 Expression::Operation::kIn, std::move(projected->term()),
329 {literal, Literal::Int(value + 1)});
330 }
331 return projected;
332 }
333
334 case Expression::Operation::kIn: {
335 ICEBERG_DCHECK(!projected->literals().empty(), "Expected at least one literal");
336 const auto& literals = projected->literals();
337 ICEBERG_DCHECK(
338 std::ranges::all_of(literals,
339 [](const auto& lit) {
340 return std::holds_alternative<int32_t>(lit.value());
341 }),
342 "Expected int32_t");
343 std::unordered_set<int32_t> value_set;
344 bool has_negative_value = false;
345 for (const auto& lit : literals) {
346 auto value = std::get<int32_t>(lit.value());
347 value_set.insert(value);
348 if (value < 0) {
349 value_set.insert(value + 1);
350 has_negative_value = true;
351 }
352 }
353 if (has_negative_value) {
354 auto values =
355 std::views::transform(value_set,
356 [](int32_t value) { return Literal::Int(value); }) |
357 std::ranges::to<std::vector>();
358 return UnboundPredicateImpl<BoundReference>::Make(Expression::Operation::kIn,
359 std::move(projected->term()),
360 std::move(values));
361 }
362 return projected;
363 }
364
365 case Expression::Operation::kNotIn:
366 case Expression::Operation::kNotEq:
367 // there is no inclusive projection for NOT_EQ and NOT_IN
368 return nullptr;
369
370 default:
371 return projected;
372 }
373 }
374
375 // Fixes a strict projection to account for incorrectly transformed values.
376 // align with Java implementation:
377 // https://github.com/apache/iceberg/blob/1.10.x/api/src/main/java/org/apache/iceberg/transforms/ProjectionUtil.java#L347
378 static Result<std::unique_ptr<UnboundPredicate>> FixStrictTimeProjection(
379 std::unique_ptr<UnboundPredicateImpl<BoundReference>> projected) {
380 if (projected == nullptr) {
381 return nullptr;
382 }
383
384 switch (projected->op()) {
385 case Expression::Operation::kLt:
386 case Expression::Operation::kLtEq:
387 // the correct bound is a correct strict projection for the incorrectly
388 // transformed values.
389 return projected;
390
391 case Expression::Operation::kGt: {
392 // GT and GT_EQ need to be adjusted because values that do not match the predicate
393 // may have been transformed into partition values that match the projected
394 // predicate.
395 ICEBERG_DCHECK(!projected->literals().empty(), "Expected at least one literal");
396 const auto& literal = projected->literals().front();
397 ICEBERG_DCHECK(std::holds_alternative<int32_t>(literal.value()),
398 "Expected int32_t");
399 if (auto value = std::get<int32_t>(literal.value()); value <= 0) {
400 return UnboundPredicateImpl<BoundReference>::Make(Expression::Operation::kGt,
401 std::move(projected->term()),
402 Literal::Int(value + 1));
403 }
404 return projected;
405 }
406
407 case Expression::Operation::kGtEq: {
408 ICEBERG_DCHECK(!projected->literals().empty(), "Expected at least one literal");
409 const auto& literal = projected->literals().front();
410 ICEBERG_DCHECK(std::holds_alternative<int32_t>(literal.value()),
411 "Expected int32_t");
412 if (auto value = std::get<int32_t>(literal.value()); value <= 0) {
413 return UnboundPredicateImpl<BoundReference>::Make(Expression::Operation::kGtEq,
414 std::move(projected->term()),
415 Literal::Int(value + 1));
416 }
417 return projected;
418 }
419
420 case Expression::Operation::kEq:
421 case Expression::Operation::kIn:
422 // there is no strict projection for EQ and IN
423 return nullptr;
424
425 case Expression::Operation::kNotEq: {
426 ICEBERG_DCHECK(!projected->literals().empty(), "Expected at least one literal");
427 const auto& literal = projected->literals().front();
428 ICEBERG_DCHECK(std::holds_alternative<int32_t>(literal.value()),
429 "Expected int32_t");
430 if (auto value = std::get<int32_t>(literal.value()); value < 0) {
432 Expression::Operation::kNotIn, std::move(projected->term()),
433 {literal, Literal::Int(value + 1)});
434 }
435 return projected;
436 }
437
438 case Expression::Operation::kNotIn: {
439 ICEBERG_DCHECK(!projected->literals().empty(), "Expected at least one literal");
440 const auto& literals = projected->literals();
441 ICEBERG_DCHECK(
442 std::ranges::all_of(literals,
443 [](const auto& lit) {
444 return std::holds_alternative<int32_t>(lit.value());
445 }),
446 "Expected int32_t");
447 std::unordered_set<int32_t> value_set;
448 bool has_negative_value = false;
449 for (const auto& lit : literals) {
450 auto value = std::get<int32_t>(lit.value());
451 value_set.insert(value);
452 if (value < 0) {
453 value_set.insert(value + 1);
454 has_negative_value = true;
455 }
456 }
457 if (has_negative_value) {
458 auto values =
459 std::views::transform(value_set,
460 [](int32_t value) { return Literal::Int(value); }) |
461 std::ranges::to<std::vector>();
462 return UnboundPredicateImpl<BoundReference>::Make(Expression::Operation::kNotIn,
463 std::move(projected->term()),
464 std::move(values));
465 }
466 return projected;
467 }
468
469 default:
470 return nullptr;
471 }
472 }
473
474 public:
475 static Result<std::unique_ptr<UnboundPredicate>> IdentityProject(
476 std::string_view name, const std::shared_ptr<BoundPredicate>& pred) {
477 ICEBERG_ASSIGN_OR_RAISE(auto ref, NamedReference::Make(std::string(name)));
478 switch (pred->kind()) {
479 case BoundPredicate::Kind::kUnary: {
480 return UnboundPredicateImpl<BoundReference>::Make(pred->op(), std::move(ref));
481 }
482 case BoundPredicate::Kind::kLiteral: {
483 const auto& literalPredicate =
484 internal::checked_pointer_cast<BoundLiteralPredicate>(pred);
485 return UnboundPredicateImpl<BoundReference>::Make(pred->op(), std::move(ref),
486 literalPredicate->literal());
487 }
488 case BoundPredicate::Kind::kSet: {
489 const auto& setPredicate =
490 internal::checked_pointer_cast<BoundSetPredicate>(pred);
492 pred->op(), std::move(ref),
493 std::vector<Literal>(setPredicate->literal_set().begin(),
494 setPredicate->literal_set().end()));
495 }
496 }
497 std::unreachable();
498 }
499
500 static Result<std::unique_ptr<UnboundPredicate>> BucketProject(
501 std::string_view name, const std::shared_ptr<BoundPredicate>& pred,
502 const std::shared_ptr<TransformFunction>& func) {
503 ICEBERG_ASSIGN_OR_RAISE(auto ref, NamedReference::Make(std::string(name)));
504 switch (pred->kind()) {
505 case BoundPredicate::Kind::kUnary: {
506 return UnboundPredicateImpl<BoundReference>::Make(pred->op(), std::move(ref));
507 }
508 case BoundPredicate::Kind::kLiteral: {
509 if (pred->op() == Expression::Operation::kEq) {
510 const auto& literalPredicate =
511 internal::checked_pointer_cast<BoundLiteralPredicate>(pred);
512 ICEBERG_ASSIGN_OR_RAISE(auto transformed,
513 func->Transform(literalPredicate->literal()));
514 return UnboundPredicateImpl<BoundReference>::Make(pred->op(), std::move(ref),
515 std::move(transformed));
516 }
517 break;
518 }
519 case BoundPredicate::Kind::kSet: {
520 // notIn can't be projected
521 if (pred->op() == Expression::Operation::kIn) {
522 const auto& setPredicate =
523 internal::checked_pointer_cast<BoundSetPredicate>(pred);
524 return TransformSet(name, setPredicate, func);
525 }
526 break;
527 }
528 }
529
530 // comparison predicates can't be projected, notEq can't be projected
531 // TODO(anyone): small ranges can be projected.
532 // for example, (x > 0) and (x < 3) can be turned into in({1, 2}) and projected.
533 return nullptr;
534 }
535
536 static Result<std::unique_ptr<UnboundPredicate>> TruncateProject(
537 std::string_view name, const std::shared_ptr<BoundPredicate>& pred,
538 const std::shared_ptr<TransformFunction>& func) {
539 ICEBERG_ASSIGN_OR_RAISE(auto ref, NamedReference::Make(std::string(name)));
540 // Handle unary predicates uniformly for all types
541 if (pred->kind() == BoundPredicate::Kind::kUnary) {
542 return UnboundPredicateImpl<BoundReference>::Make(pred->op(), std::move(ref));
543 }
544
545 // Handle set predicates (kIn) uniformly for all types
546 if (pred->kind() == BoundPredicate::Kind::kSet) {
547 if (pred->op() == Expression::Operation::kIn) {
548 const auto& setPredicate =
549 internal::checked_pointer_cast<BoundSetPredicate>(pred);
550 return TransformSet(name, setPredicate, func);
551 }
552 return nullptr;
553 }
554
555 // Handle literal predicates based on source type
556 const auto& literalPredicate =
557 internal::checked_pointer_cast<BoundLiteralPredicate>(pred);
558
559 switch (func->source_type()->type_id()) {
560 case TypeId::kInt:
561 case TypeId::kLong:
562 case TypeId::kDecimal:
563 return TransformNumeric(name, literalPredicate, func);
564 case TypeId::kString:
565 return TruncateStringLiteral(name, literalPredicate, func);
566 case TypeId::kBinary:
567 return TruncateByteArray(name, literalPredicate, func);
568 default:
569 return NotSupported("{} is not a valid input type for truncate transform",
570 func->source_type()->ToString());
571 }
572 }
573
574 static Result<std::unique_ptr<UnboundPredicate>> TemporalProject(
575 std::string_view name, const std::shared_ptr<BoundPredicate>& pred,
576 const std::shared_ptr<TransformFunction>& func) {
577 ICEBERG_ASSIGN_OR_RAISE(auto ref, NamedReference::Make(std::string(name)));
578 if (pred->kind() == BoundPredicate::Kind::kUnary) {
579 return UnboundPredicateImpl<BoundReference>::Make(pred->op(), std::move(ref));
580 } else if (pred->kind() == BoundPredicate::Kind::kLiteral) {
581 const auto& literalPredicate =
582 internal::checked_pointer_cast<BoundLiteralPredicate>(pred);
583 ICEBERG_ASSIGN_OR_RAISE(auto projected,
584 TransformNumeric(name, literalPredicate, func));
585 if (func->transform_type() != TransformType::kDay ||
586 func->source_type()->type_id() != TypeId::kDate) {
587 return FixInclusiveTimeProjection(
588 internal::checked_pointer_cast<UnboundPredicateImpl<BoundReference>>(
589 std::move(projected)));
590 }
591 return projected;
592 } else if (pred->kind() == BoundPredicate::Kind::kSet &&
593 pred->op() == Expression::Operation::kIn) {
594 const auto& setPredicate = internal::checked_pointer_cast<BoundSetPredicate>(pred);
595 ICEBERG_ASSIGN_OR_RAISE(auto projected, TransformSet(name, setPredicate, func));
596 if (func->transform_type() != TransformType::kDay ||
597 func->source_type()->type_id() != TypeId::kDate) {
598 return FixInclusiveTimeProjection(
599 internal::checked_pointer_cast<UnboundPredicateImpl<BoundReference>>(
600 std::move(projected)));
601 }
602 return projected;
603 }
604
605 return nullptr;
606 }
607
608 static Result<std::unique_ptr<UnboundPredicate>> RemoveTransform(
609 std::string_view name, const std::shared_ptr<BoundPredicate>& pred) {
610 ICEBERG_ASSIGN_OR_RAISE(auto ref, NamedReference::Make(std::string(name)));
611 switch (pred->kind()) {
612 case BoundPredicate::Kind::kUnary: {
613 return UnboundPredicateImpl<BoundReference>::Make(pred->op(), std::move(ref));
614 }
615 case BoundPredicate::Kind::kLiteral: {
616 const auto& literalPredicate =
617 internal::checked_pointer_cast<BoundLiteralPredicate>(pred);
618 return UnboundPredicateImpl<BoundReference>::Make(pred->op(), std::move(ref),
619 literalPredicate->literal());
620 }
621 case BoundPredicate::Kind::kSet: {
622 const auto& setPredicate =
623 internal::checked_pointer_cast<BoundSetPredicate>(pred);
625 pred->op(), std::move(ref),
626 std::vector<Literal>(setPredicate->literal_set().begin(),
627 setPredicate->literal_set().end()));
628 }
629 }
630 std::unreachable();
631 }
632
633 static Result<std::unique_ptr<UnboundPredicate>> BucketProjectStrict(
634 std::string_view name, const std::shared_ptr<BoundPredicate>& pred,
635 const std::shared_ptr<TransformFunction>& func) {
636 ICEBERG_ASSIGN_OR_RAISE(auto ref, NamedReference::Make(std::string(name)));
637 switch (pred->kind()) {
638 case BoundPredicate::Kind::kUnary: {
639 return UnboundPredicateImpl<BoundReference>::Make(pred->op(), std::move(ref));
640 }
641 case BoundPredicate::Kind::kLiteral: {
642 if (pred->op() == Expression::Operation::kNotEq) {
643 const auto& literalPredicate =
644 internal::checked_pointer_cast<BoundLiteralPredicate>(pred);
645 ICEBERG_ASSIGN_OR_RAISE(auto transformed,
646 func->Transform(literalPredicate->literal()));
647 // TODO(anyone): need to translate not(eq(...)) into notEq in expressions
648 return UnboundPredicateImpl<BoundReference>::Make(pred->op(), std::move(ref),
649 std::move(transformed));
650 }
651 break;
652 }
653 case BoundPredicate::Kind::kSet: {
654 if (pred->op() == Expression::Operation::kNotIn) {
655 const auto& setPredicate =
656 internal::checked_pointer_cast<BoundSetPredicate>(pred);
657 return TransformSet(name, setPredicate, func);
658 }
659 break;
660 }
661 }
662
663 // no strict projection for comparison or equality
664 return nullptr;
665 }
666
667 static Result<std::unique_ptr<UnboundPredicate>> TruncateProjectStrict(
668 std::string_view name, const std::shared_ptr<BoundPredicate>& pred,
669 const std::shared_ptr<TransformFunction>& func) {
670 ICEBERG_ASSIGN_OR_RAISE(auto ref, NamedReference::Make(std::string(name)));
671 // Handle unary predicates uniformly for all types
672 if (pred->kind() == BoundPredicate::Kind::kUnary) {
673 return UnboundPredicateImpl<BoundReference>::Make(pred->op(), std::move(ref));
674 }
675
676 // Handle set predicates (kNotIn) uniformly for all types
677 if (pred->kind() == BoundPredicate::Kind::kSet) {
678 if (pred->op() == Expression::Operation::kNotIn) {
679 const auto& setPredicate =
680 internal::checked_pointer_cast<BoundSetPredicate>(pred);
681 return TransformSet(name, setPredicate, func);
682 }
683 return nullptr;
684 }
685
686 // Handle literal predicates based on source type
687 const auto& literalPredicate =
688 internal::checked_pointer_cast<BoundLiteralPredicate>(pred);
689
690 switch (func->source_type()->type_id()) {
691 case TypeId::kInt:
692 case TypeId::kLong:
693 case TypeId::kDecimal:
694 return TransformNumericStrict(name, literalPredicate, func);
695 case TypeId::kString:
696 return TruncateStringLiteralStrict(name, literalPredicate, func);
697 case TypeId::kBinary:
698 return TruncateByteArrayStrict(name, literalPredicate, func);
699 default:
700 return NotSupported("{} is not a valid input type for truncate transform",
701 func->source_type()->ToString());
702 }
703 }
704
705 static Result<std::unique_ptr<UnboundPredicate>> TemporalProjectStrict(
706 std::string_view name, const std::shared_ptr<BoundPredicate>& pred,
707 const std::shared_ptr<TransformFunction>& func) {
708 ICEBERG_ASSIGN_OR_RAISE(auto ref, NamedReference::Make(std::string(name)));
709 if (pred->kind() == BoundPredicate::Kind::kUnary) {
710 return UnboundPredicateImpl<BoundReference>::Make(pred->op(), std::move(ref));
711 } else if (pred->kind() == BoundPredicate::Kind::kLiteral) {
712 const auto& literalPredicate =
713 internal::checked_pointer_cast<BoundLiteralPredicate>(pred);
714 ICEBERG_ASSIGN_OR_RAISE(auto projected,
715 TransformNumericStrict(name, literalPredicate, func));
716 if (func->transform_type() != TransformType::kDay ||
717 func->source_type()->type_id() != TypeId::kDate) {
718 return FixStrictTimeProjection(
719 internal::checked_pointer_cast<UnboundPredicateImpl<BoundReference>>(
720 std::move(projected)));
721 }
722 return projected;
723 } else if (pred->kind() == BoundPredicate::Kind::kSet &&
724 pred->op() == Expression::Operation::kNotIn) {
725 const auto& setPredicate = internal::checked_pointer_cast<BoundSetPredicate>(pred);
726 ICEBERG_ASSIGN_OR_RAISE(auto projected, TransformSet(name, setPredicate, func));
727 if (func->transform_type() != TransformType::kDay ||
728 func->source_type()->type_id() != TypeId::kDate) {
729 return FixStrictTimeProjection(
730 internal::checked_pointer_cast<UnboundPredicateImpl<BoundReference>>(
731 std::move(projected)));
732 }
733 return projected;
734 }
735
736 return nullptr;
737 }
738};
739
740} // namespace iceberg
Checked cast functions for dynamic_cast and static_cast. Adapted from Apache Arrow https://github....
Represents 128-bit fixed-point decimal numbers. The max decimal precision that can be safely represen...
Definition decimal.h:46
Operation
Operation types for expressions.
Definition expression.h:40
Literal is a literal value that is associated with a primitive type.
Definition literal.h:39
const Value & value() const
Get the literal value.
Definition literal.h:109
const std::shared_ptr< PrimitiveType > & type() const
Get the literal type.
Definition literal.cc:364
static Literal Decimal(int128_t value, int32_t precision, int32_t scale)
Create a decimal literal.
Definition literal.cc:349
static Result< std::unique_ptr< NamedReference > > Make(std::string field_name)
Create a named reference to a field.
Definition term.cc:45
Definition projection_util_internal.h:43
static size_t CodePointCount(std::string_view str)
Count the number of code points in a UTF-8 string.
Definition string_util.h:66
static Result< std::unique_ptr< UnboundPredicateImpl< B > > > Make(Expression::Operation op, std::shared_ptr< UnboundTerm< B > > term)
Create an unbound predicate (unary operation).
Definition predicate.cc:47
128-bit fixed-point decimal numbers. Adapted from Apache Arrow with only Decimal128 support....