// Returns a new literal whose value is `literal`'s value shifted by
// `adjustment`. PlusOne/MinusOne below call it to convert between strict
// (<, >) and inclusive (<=, >=) bounds.
// - int32_t-backed results (Literal::Int, Literal::Date) and int64_t-backed
//   results (Literal::Long, Literal::Timestamp, Literal::TimestampTz) add
//   `adjustment` directly to the raw stored value.
// - Decimal literals are rebuilt using the literal's own DecimalType.
// - Any other type yields NotSupported.
// NOTE(review): the additions below are plain signed arithmetic. Adjusting a
// value at INT32_MAX/INT64_MAX by +1, or at the MIN value by -1, overflows,
// which is undefined behavior in C++. Consider range-checking and returning an
// error (or skipping the adjustment) at the type boundaries.
45 static Result<Literal> AdjustLiteral(
const Literal& literal,
int adjustment) {
46 switch (literal.
type()->type_id()) {
48 return Literal::Int(std::get<int32_t>(literal.
value()) + adjustment);
50 return Literal::Long(std::get<int64_t>(literal.
value()) + adjustment);
52 return Literal::Date(std::get<int32_t>(literal.
value()) + adjustment);
53 case TypeId::kTimestamp:
54 return Literal::Timestamp(std::get<int64_t>(literal.
value()) + adjustment);
55 case TypeId::kTimestampTz:
56 return Literal::TimestampTz(std::get<int64_t>(literal.
value()) + adjustment);
// Decimals: the adjusted value keeps the literal's declared scale.
57 case TypeId::kDecimal: {
58 const auto& decimal_type =
59 internal::checked_cast<const DecimalType&>(*literal.
type());
62 decimal_type.scale());
// Fallback for literal types that have no notion of "next/previous value".
65 return NotSupported(
"{} is not a valid literal type for value adjustment",
66 literal.
type()->ToString());
// Returns `literal` incremented by one, via AdjustLiteral(literal, +1).
// Propagates AdjustLiteral's NotSupported error for non-numeric types.
70 static Result<Literal> PlusOne(
const Literal& literal) {
71 return AdjustLiteral(literal, +1);
// Returns `literal` decremented by one, via AdjustLiteral(literal, -1).
// Propagates AdjustLiteral's NotSupported error for non-numeric types.
74 static Result<Literal> MinusOne(
const Literal& literal) {
75 return AdjustLiteral(literal, -1);
// Builds an unbound predicate on the partition field `name`. Callers pass
// (op, name, func, literal).
// The predicate's literal is `func` applied to `literal`; any error from
// func->Transform is propagated via ICEBERG_ASSIGN_OR_RAISE.
78 static Result<std::unique_ptr<UnboundPredicate>> MakePredicate(
80 const std::shared_ptr<TransformFunction>& func,
const Literal& literal) {
82 ICEBERG_ASSIGN_OR_RAISE(
auto lit, func->Transform(literal));
// Projects a set predicate (IN / NOT IN) onto the partition field `name`.
// Each literal in `pred`'s set is passed through `func`, and the results are
// collected into a vector used to build the projected predicate. The operation
// is presumably the same as pred->op(); confirm against the full source.
// Returns the first transform error encountered.
// NOTE(review): transformed literals are not deduplicated here, so distinct
// inputs that map to the same partition value (e.g. the same bucket) appear
// more than once.
86 static Result<std::unique_ptr<UnboundPredicate>> TransformSet(
87 std::string_view name,
const std::shared_ptr<BoundSetPredicate>& pred,
88 const std::shared_ptr<TransformFunction>& func) {
89 std::vector<Literal> transformed;
90 transformed.reserve(pred->literal_set().size());
91 for (
const auto& lit : pred->literal_set()) {
92 ICEBERG_ASSIGN_OR_RAISE(
auto transformed_lit, func->Transform(lit));
93 transformed.push_back(std::move(transformed_lit));
97 std::move(transformed));
// Inclusive projection of a literal predicate through a truncate-style
// transform (used for binary and, via TruncateStringLiteral, string sources).
// Mapping visible below:
//   x <  v, x <= v  ->  p <= t(v)
//   x >  v, x >= v  ->  p >= t(v)
//   x == v, starts_with(x, v)  ->  same operation on t(v)
// Why: assuming t is non-decreasing (a prefix never sorts after its source),
// strict bounds can only be kept as inclusive bounds once values are truncated.
100 static Result<std::unique_ptr<UnboundPredicate>> TruncateByteArray(
101 std::string_view name,
const std::shared_ptr<BoundLiteralPredicate>& pred,
102 const std::shared_ptr<TransformFunction>& func) {
103 switch (pred->op()) {
104 case Expression::Operation::kLt:
105 case Expression::Operation::kLtEq:
106 return MakePredicate(Expression::Operation::kLtEq, name, func, pred->literal());
107 case Expression::Operation::kGt:
108 case Expression::Operation::kGtEq:
109 return MakePredicate(Expression::Operation::kGtEq, name, func, pred->literal());
110 case Expression::Operation::kEq:
111 case Expression::Operation::kStartsWith:
112 return MakePredicate(pred->op(), name, func, pred->literal());
// Strict projection of a literal predicate through a truncate-style transform:
// produces a partition predicate that, when true, guarantees every row in the
// partition satisfies `pred`.
// Mapping visible below:
//   x <  v, x <= v  ->  p <  t(v)
//   x >  v, x >= v  ->  p >  t(v)
//   x != v          ->  p != t(v)
// Why: assuming t is non-decreasing, only a strict comparison on the truncated
// value rules out partitions that straddle the boundary.
118 static Result<std::unique_ptr<UnboundPredicate>> TruncateByteArrayStrict(
119 std::string_view name,
const std::shared_ptr<BoundLiteralPredicate>& pred,
120 const std::shared_ptr<TransformFunction>& func) {
121 switch (pred->op()) {
122 case Expression::Operation::kLt:
123 case Expression::Operation::kLtEq:
124 return MakePredicate(Expression::Operation::kLt, name, func, pred->literal());
125 case Expression::Operation::kGt:
126 case Expression::Operation::kGtEq:
127 return MakePredicate(Expression::Operation::kGt, name, func, pred->literal());
128 case Expression::Operation::kNotEq:
129 return MakePredicate(Expression::Operation::kNotEq, name, func, pred->literal());
// Inclusive projection of a literal predicate through a transform over a
// discrete numeric/temporal source (decimal and timestamp types are among the
// accepted sources; others yield NotSupported).
// Strict bounds are first made inclusive on the source value, then transformed:
//   x <  v  ->  p <= t(v - 1)
//   x >  v  ->  p >= t(v + 1)
//   x <= v, x >= v, x == v  ->  same operation on t(v)
136 static Result<std::unique_ptr<UnboundPredicate>> TransformNumeric(
137 std::string_view name,
const std::shared_ptr<BoundLiteralPredicate>& pred,
138 const std::shared_ptr<TransformFunction>& func) {
// Validate that the transform's source type supports +/-1 adjustment.
139 switch (func->source_type()->type_id()) {
142 case TypeId::kDecimal:
144 case TypeId::kTimestamp:
145 case TypeId::kTimestampTz:
148 return NotSupported(
"{} is not a valid input type for numeric transform",
149 func->source_type()->ToString());
152 switch (pred->op()) {
// x < v  is equivalent to  x <= v - 1  for discrete values.
153 case Expression::Operation::kLt: {
155 ICEBERG_ASSIGN_OR_RAISE(
auto adjusted, MinusOne(pred->literal()));
156 return MakePredicate(Expression::Operation::kLtEq, name, func, adjusted);
// x > v  is equivalent to  x >= v + 1  for discrete values.
158 case Expression::Operation::kGt: {
160 ICEBERG_ASSIGN_OR_RAISE(
auto adjusted, PlusOne(pred->literal()));
161 return MakePredicate(Expression::Operation::kGtEq, name, func, adjusted);
163 case Expression::Operation::kLtEq:
164 case Expression::Operation::kGtEq:
165 case Expression::Operation::kEq:
166 return MakePredicate(pred->op(), name, func, pred->literal());
// Strict projection of a literal predicate through a transform over a discrete
// numeric/temporal source (decimal and timestamp types are among the accepted
// sources; others yield NotSupported).
// Inclusive bounds are first made strict on the source value, then transformed:
//   x <= v  ->  p < t(v + 1)
//   x >= v  ->  p > t(v - 1)
//   x < v, x > v, x != v  ->  same operation on t(v)
172 static Result<std::unique_ptr<UnboundPredicate>> TransformNumericStrict(
173 std::string_view name,
const std::shared_ptr<BoundLiteralPredicate>& pred,
174 const std::shared_ptr<TransformFunction>& func) {
// Validate that the transform's source type supports +/-1 adjustment.
175 switch (func->source_type()->type_id()) {
178 case TypeId::kDecimal:
180 case TypeId::kTimestamp:
181 case TypeId::kTimestampTz:
184 return NotSupported(
"{} is not a valid input type for numeric transform",
185 func->source_type()->ToString());
188 switch (pred->op()) {
// x <= v  is equivalent to  x < v + 1  for discrete values.
189 case Expression::Operation::kLtEq: {
190 ICEBERG_ASSIGN_OR_RAISE(
auto adjusted, PlusOne(pred->literal()));
191 return MakePredicate(Expression::Operation::kLt, name, func, adjusted);
// x >= v  is equivalent to  x > v - 1  for discrete values.
193 case Expression::Operation::kGtEq: {
194 ICEBERG_ASSIGN_OR_RAISE(
auto adjusted, MinusOne(pred->literal()));
195 return MakePredicate(Expression::Operation::kGt, name, func, adjusted);
197 case Expression::Operation::kLt:
198 case Expression::Operation::kGt:
199 case Expression::Operation::kNotEq:
200 return MakePredicate(pred->op(), name, func, pred->literal());
// Inclusive projection of a string literal predicate through a truncate
// transform. `func` must be a TruncateTransform (checked_pointer_cast below).
// Operations other than starts_with / not_starts_with are delegated to
// TruncateByteArray. For the prefix operations, the literal's length is
// compared against the truncate width:
//   length <  width -> same operation on the (unchanged-by-truncation) literal
//   length == width -> starts_with becomes ==, not_starts_with becomes !=
//   length >  width -> starts_with falls back to TruncateByteArray
// NOTE(review): `length` presumably counts in the same units as the truncate
// width (code points for strings); confirm.
206 static Result<std::unique_ptr<UnboundPredicate>> TruncateStringLiteral(
207 std::string_view name,
const std::shared_ptr<BoundLiteralPredicate>& pred,
208 const std::shared_ptr<TransformFunction>& func) {
209 const auto op = pred->op();
210 if (op != Expression::Operation::kStartsWith &&
211 op != Expression::Operation::kNotStartsWith) {
212 return TruncateByteArray(name, pred, func);
215 const auto& literal = pred->literal();
218 const auto width =
static_cast<size_t>(
219 internal::checked_pointer_cast<TruncateTransform>(func)->width());
// A prefix shorter than the width survives truncation intact.
221 if (length < width) {
222 return MakePredicate(op, name, func, literal);
// A prefix exactly as long as the width equals the whole truncated value.
225 if (length == width) {
226 if (op == Expression::Operation::kStartsWith) {
227 return MakePredicate(Expression::Operation::kEq, name, func, literal);
229 return MakePredicate(Expression::Operation::kNotEq, name, func, literal);
233 if (op == Expression::Operation::kStartsWith) {
234 return TruncateByteArray(name, pred, func);
// Strict projection of a string literal predicate through a truncate
// transform. `func` must be a TruncateTransform (checked_pointer_cast below).
// Operations other than starts_with / not_starts_with are delegated to
// TruncateByteArrayStrict. For the prefix operations:
//   length <  width -> same operation on the literal
//   length == width -> starts_with becomes ==, not_starts_with becomes !=
//   length >  width -> not_starts_with is kept, applied to t(literal)
// NOTE(review): `length` presumably counts in the same units as the truncate
// width (code points for strings); confirm.
240 static Result<std::unique_ptr<UnboundPredicate>> TruncateStringLiteralStrict(
241 std::string_view name,
const std::shared_ptr<BoundLiteralPredicate>& pred,
242 const std::shared_ptr<TransformFunction>& func) {
243 const auto op = pred->op();
244 if (op != Expression::Operation::kStartsWith &&
245 op != Expression::Operation::kNotStartsWith) {
246 return TruncateByteArrayStrict(name, pred, func);
249 const auto& literal = pred->literal();
252 const auto width =
static_cast<size_t>(
253 internal::checked_pointer_cast<TruncateTransform>(func)->width());
255 if (length < width) {
256 return MakePredicate(op, name, func, literal);
259 if (length == width) {
260 if (op == Expression::Operation::kStartsWith) {
261 return MakePredicate(Expression::Operation::kEq, name, func, literal);
263 return MakePredicate(Expression::Operation::kNotEq, name, func, literal);
// Longer prefix: if the truncated partition value doesn't start with t(v),
// no row in it can start with v. MakePredicate applies the truncation.
267 if (op == Expression::Operation::kNotStartsWith) {
268 return MakePredicate(Expression::Operation::kNotStartsWith, name, func, literal);
// Post-processes an inclusive time-transform projection (year/month/day/hour
// ordinals stored as int32_t) by widening it for negative ordinals.
// - <, <= with a negative bound: the bound is replaced by bound + 1.
// - == with a negative value v: becomes IN {v, v + 1}.
// - IN: every negative value v also contributes v + 1; the predicate is rebuilt
//   only if some negative value was seen.
// NOTE(review): the +1 widening for negative (pre-epoch) ordinals presumably
// accounts for partition values written by older writers that rounded toward
// zero instead of flooring. Confirm against the Iceberg spec / Java
// ProjectionUtil.
277 static Result<std::unique_ptr<UnboundPredicate>> FixInclusiveTimeProjection(
278 std::unique_ptr<UnboundPredicateImpl<BoundReference>> projected) {
279 if (projected ==
nullptr) {
285 switch (projected->op()) {
286 case Expression::Operation::kLt: {
287 ICEBERG_DCHECK(!projected->literals().empty(),
"Expected at least one literal");
288 const auto& literal = projected->literals().front();
289 ICEBERG_DCHECK(std::holds_alternative<int32_t>(literal.
value()),
291 if (
auto value = std::get<int32_t>(literal.
value()); value < 0) {
293 std::move(projected->term()),
294 Literal::Int(value + 1));
300 case Expression::Operation::kLtEq: {
301 ICEBERG_DCHECK(!projected->literals().empty(),
"Expected at least one literal");
302 const auto& literal = projected->literals().front();
303 ICEBERG_DCHECK(std::holds_alternative<int32_t>(literal.
value()),
306 if (
auto value = std::get<int32_t>(literal.
value()); value < 0) {
308 std::move(projected->term()),
309 Literal::Int(value + 1));
314 case Expression::Operation::kGt:
315 case Expression::Operation::kGtEq:
// A negative equality target also matches the adjacent ordinal.
319 case Expression::Operation::kEq: {
320 ICEBERG_DCHECK(!projected->literals().empty(),
"Expected at least one literal");
321 const auto& literal = projected->literals().front();
322 ICEBERG_DCHECK(std::holds_alternative<int32_t>(literal.
value()),
324 if (
auto value = std::get<int32_t>(literal.
value()); value < 0) {
328 Expression::Operation::kIn, std::move(projected->term()),
329 {literal, Literal::Int(value + 1)});
334 case Expression::Operation::kIn: {
335 ICEBERG_DCHECK(!projected->literals().empty(),
"Expected at least one literal");
336 const auto& literals = projected->literals();
338 std::ranges::all_of(literals,
339 [](
const auto& lit) {
340 return std::holds_alternative<int32_t>(lit.value());
// The set deduplicates v and v + 1 when both are already present.
343 std::unordered_set<int32_t> value_set;
344 bool has_negative_value =
false;
345 for (
const auto& lit : literals) {
346 auto value = std::get<int32_t>(lit.value());
347 value_set.insert(value);
349 value_set.insert(value + 1);
350 has_negative_value =
true;
// Only rebuild when widening actually happened. Iteration order of
// value_set is unspecified, so the literal order of the result is too.
353 if (has_negative_value) {
355 std::views::transform(value_set,
356 [](int32_t value) {
return Literal::Int(value); }) |
357 std::ranges::to<std::vector>();
359 std::move(projected->term()),
365 case Expression::Operation::kNotIn:
366 case Expression::Operation::kNotEq:
// Post-processes a strict time-transform projection (int32_t ordinals) so it
// stays strict when negative ordinals may have been written as value + 1.
// - >, >= with a bound <= 0: the bound is raised to bound + 1.
// - != with a negative value v: becomes NOT IN {v, v + 1}.
// - NOT IN: every negative value v also contributes v + 1; the predicate is
//   rebuilt only if some negative value was seen.
// NOTE(review): as with the inclusive fix, the +1 handling presumably
// compensates for older writers rounding pre-epoch times toward zero. Confirm
// against the Iceberg spec / Java ProjectionUtil.
378 static Result<std::unique_ptr<UnboundPredicate>> FixStrictTimeProjection(
379 std::unique_ptr<UnboundPredicateImpl<BoundReference>> projected) {
380 if (projected ==
nullptr) {
384 switch (projected->op()) {
385 case Expression::Operation::kLt:
386 case Expression::Operation::kLtEq:
// Note the `<= 0` test here (vs `< 0` for != / NOT IN below).
391 case Expression::Operation::kGt: {
395 ICEBERG_DCHECK(!projected->literals().empty(),
"Expected at least one literal");
396 const auto& literal = projected->literals().front();
397 ICEBERG_DCHECK(std::holds_alternative<int32_t>(literal.
value()),
399 if (
auto value = std::get<int32_t>(literal.
value()); value <= 0) {
401 std::move(projected->term()),
402 Literal::Int(value + 1));
407 case Expression::Operation::kGtEq: {
408 ICEBERG_DCHECK(!projected->literals().empty(),
"Expected at least one literal");
409 const auto& literal = projected->literals().front();
410 ICEBERG_DCHECK(std::holds_alternative<int32_t>(literal.
value()),
412 if (
auto value = std::get<int32_t>(literal.
value()); value <= 0) {
414 std::move(projected->term()),
415 Literal::Int(value + 1));
420 case Expression::Operation::kEq:
421 case Expression::Operation::kIn:
// A negative excluded value must also exclude its adjacent ordinal.
425 case Expression::Operation::kNotEq: {
426 ICEBERG_DCHECK(!projected->literals().empty(),
"Expected at least one literal");
427 const auto& literal = projected->literals().front();
428 ICEBERG_DCHECK(std::holds_alternative<int32_t>(literal.
value()),
430 if (
auto value = std::get<int32_t>(literal.
value()); value < 0) {
432 Expression::Operation::kNotIn, std::move(projected->term()),
433 {literal, Literal::Int(value + 1)});
438 case Expression::Operation::kNotIn: {
439 ICEBERG_DCHECK(!projected->literals().empty(),
"Expected at least one literal");
440 const auto& literals = projected->literals();
442 std::ranges::all_of(literals,
443 [](
const auto& lit) {
444 return std::holds_alternative<int32_t>(lit.value());
447 std::unordered_set<int32_t> value_set;
448 bool has_negative_value =
false;
449 for (
const auto& lit : literals) {
450 auto value = std::get<int32_t>(lit.value());
451 value_set.insert(value);
453 value_set.insert(value + 1);
454 has_negative_value =
true;
// Only rebuild when widening actually happened. Literal order follows the
// unordered_set and is unspecified.
457 if (has_negative_value) {
459 std::views::transform(value_set,
460 [](int32_t value) {
return Literal::Int(value); }) |
461 std::ranges::to<std::vector>();
463 std::move(projected->term()),
// Projection for the identity transform: the predicate is re-expressed on the
// partition field `name` with its operation and literal(s) unchanged.
// Literal predicates reuse the bound literal. Set predicates copy the literal
// set into a vector. `ref` is presumably a reference to `name` created earlier
// in the function; confirm.
475 static Result<std::unique_ptr<UnboundPredicate>> IdentityProject(
476 std::string_view name,
const std::shared_ptr<BoundPredicate>& pred) {
478 switch (pred->kind()) {
479 case BoundPredicate::Kind::kUnary: {
482 case BoundPredicate::Kind::kLiteral: {
483 const auto& literalPredicate =
484 internal::checked_pointer_cast<BoundLiteralPredicate>(pred);
486 literalPredicate->literal());
488 case BoundPredicate::Kind::kSet: {
489 const auto& setPredicate =
490 internal::checked_pointer_cast<BoundSetPredicate>(pred);
492 pred->op(), std::move(ref),
493 std::vector<Literal>(setPredicate->literal_set().begin(),
494 setPredicate->literal_set().end()));
// Inclusive projection for the bucket transform.
// Only equality-style predicates are projected, because hashing destroys
// ordering:
//   x == v       -> predicate on bucket(v)
//   x IN {v...}  -> TransformSet over the set (bucket of each value)
// Other literal/set operations are not projected by the visible code.
500 static Result<std::unique_ptr<UnboundPredicate>> BucketProject(
501 std::string_view name,
const std::shared_ptr<BoundPredicate>& pred,
502 const std::shared_ptr<TransformFunction>& func) {
504 switch (pred->kind()) {
505 case BoundPredicate::Kind::kUnary: {
508 case BoundPredicate::Kind::kLiteral: {
509 if (pred->op() == Expression::Operation::kEq) {
510 const auto& literalPredicate =
511 internal::checked_pointer_cast<BoundLiteralPredicate>(pred);
512 ICEBERG_ASSIGN_OR_RAISE(
auto transformed,
513 func->Transform(literalPredicate->literal()));
515 std::move(transformed));
519 case BoundPredicate::Kind::kSet: {
521 if (pred->op() == Expression::Operation::kIn) {
522 const auto& setPredicate =
523 internal::checked_pointer_cast<BoundSetPredicate>(pred);
524 return TransformSet(name, setPredicate, func);
// Inclusive projection for the truncate transform.
// - Set predicates: only IN is projected, via TransformSet.
// - Literal predicates: dispatch on the transform's source type:
//     decimal (and other numeric sources) -> TransformNumeric
//     string -> TruncateStringLiteral
//     binary -> TruncateByteArray
//     anything else -> NotSupported
536 static Result<std::unique_ptr<UnboundPredicate>> TruncateProject(
537 std::string_view name,
const std::shared_ptr<BoundPredicate>& pred,
538 const std::shared_ptr<TransformFunction>& func) {
541 if (pred->kind() == BoundPredicate::Kind::kUnary) {
546 if (pred->kind() == BoundPredicate::Kind::kSet) {
547 if (pred->op() == Expression::Operation::kIn) {
548 const auto& setPredicate =
549 internal::checked_pointer_cast<BoundSetPredicate>(pred);
550 return TransformSet(name, setPredicate, func);
556 const auto& literalPredicate =
557 internal::checked_pointer_cast<BoundLiteralPredicate>(pred);
559 switch (func->source_type()->type_id()) {
562 case TypeId::kDecimal:
563 return TransformNumeric(name, literalPredicate, func);
564 case TypeId::kString:
565 return TruncateStringLiteral(name, literalPredicate, func);
566 case TypeId::kBinary:
567 return TruncateByteArray(name, literalPredicate, func);
569 return NotSupported(
"{} is not a valid input type for truncate transform",
570 func->source_type()->ToString());
// Inclusive projection for temporal transforms (year/month/day/hour).
// Literal predicates go through TransformNumeric, and IN predicates through
// TransformSet. In both cases the result is then passed to
// FixInclusiveTimeProjection, except for the day transform over a date source,
// which skips the fix.
// The checked_pointer_cast assumes the projected predicate is an
// UnboundPredicateImpl<BoundReference>.
574 static Result<std::unique_ptr<UnboundPredicate>> TemporalProject(
575 std::string_view name,
const std::shared_ptr<BoundPredicate>& pred,
576 const std::shared_ptr<TransformFunction>& func) {
578 if (pred->kind() == BoundPredicate::Kind::kUnary) {
580 }
else if (pred->kind() == BoundPredicate::Kind::kLiteral) {
581 const auto& literalPredicate =
582 internal::checked_pointer_cast<BoundLiteralPredicate>(pred);
583 ICEBERG_ASSIGN_OR_RAISE(
auto projected,
584 TransformNumeric(name, literalPredicate, func));
// day(date) is exempt from the negative-ordinal fix-up.
585 if (func->transform_type() != TransformType::kDay ||
586 func->source_type()->type_id() != TypeId::kDate) {
587 return FixInclusiveTimeProjection(
588 internal::checked_pointer_cast<UnboundPredicateImpl<BoundReference>>(
589 std::move(projected)));
592 }
else if (pred->kind() == BoundPredicate::Kind::kSet &&
593 pred->op() == Expression::Operation::kIn) {
594 const auto& setPredicate = internal::checked_pointer_cast<BoundSetPredicate>(pred);
595 ICEBERG_ASSIGN_OR_RAISE(
auto projected, TransformSet(name, setPredicate, func));
596 if (func->transform_type() != TransformType::kDay ||
597 func->source_type()->type_id() != TypeId::kDate) {
598 return FixInclusiveTimeProjection(
599 internal::checked_pointer_cast<UnboundPredicateImpl<BoundReference>>(
600 std::move(projected)));
// Re-expresses `pred` on the field `name` without applying any transform.
// The operation and literal(s) are kept as-is: literal predicates reuse the
// bound literal, and set predicates copy the literal set into a vector.
// `ref` is presumably a reference to `name` created earlier in the function;
// confirm.
608 static Result<std::unique_ptr<UnboundPredicate>> RemoveTransform(
609 std::string_view name,
const std::shared_ptr<BoundPredicate>& pred) {
611 switch (pred->kind()) {
612 case BoundPredicate::Kind::kUnary: {
615 case BoundPredicate::Kind::kLiteral: {
616 const auto& literalPredicate =
617 internal::checked_pointer_cast<BoundLiteralPredicate>(pred);
619 literalPredicate->literal());
621 case BoundPredicate::Kind::kSet: {
622 const auto& setPredicate =
623 internal::checked_pointer_cast<BoundSetPredicate>(pred);
625 pred->op(), std::move(ref),
626 std::vector<Literal>(setPredicate->literal_set().begin(),
627 setPredicate->literal_set().end()));
// Strict projection for the bucket transform. Only exclusion-style predicates
// can be projected strictly:
//   x != v           -> predicate on bucket(v)
//   x NOT IN {v...}  -> TransformSet over the set
// Other literal/set operations are not projected by the visible code.
633 static Result<std::unique_ptr<UnboundPredicate>> BucketProjectStrict(
634 std::string_view name,
const std::shared_ptr<BoundPredicate>& pred,
635 const std::shared_ptr<TransformFunction>& func) {
637 switch (pred->kind()) {
638 case BoundPredicate::Kind::kUnary: {
641 case BoundPredicate::Kind::kLiteral: {
642 if (pred->op() == Expression::Operation::kNotEq) {
643 const auto& literalPredicate =
644 internal::checked_pointer_cast<BoundLiteralPredicate>(pred);
645 ICEBERG_ASSIGN_OR_RAISE(
auto transformed,
646 func->Transform(literalPredicate->literal()));
649 std::move(transformed));
653 case BoundPredicate::Kind::kSet: {
654 if (pred->op() == Expression::Operation::kNotIn) {
655 const auto& setPredicate =
656 internal::checked_pointer_cast<BoundSetPredicate>(pred);
657 return TransformSet(name, setPredicate, func);
// Strict projection for the truncate transform.
// - Set predicates: only NOT IN is projected, via TransformSet.
// - Literal predicates: dispatch on the transform's source type:
//     decimal (and other numeric sources) -> TransformNumericStrict
//     string -> TruncateStringLiteralStrict
//     binary -> TruncateByteArrayStrict
//     anything else -> NotSupported
667 static Result<std::unique_ptr<UnboundPredicate>> TruncateProjectStrict(
668 std::string_view name,
const std::shared_ptr<BoundPredicate>& pred,
669 const std::shared_ptr<TransformFunction>& func) {
672 if (pred->kind() == BoundPredicate::Kind::kUnary) {
677 if (pred->kind() == BoundPredicate::Kind::kSet) {
678 if (pred->op() == Expression::Operation::kNotIn) {
679 const auto& setPredicate =
680 internal::checked_pointer_cast<BoundSetPredicate>(pred);
681 return TransformSet(name, setPredicate, func);
687 const auto& literalPredicate =
688 internal::checked_pointer_cast<BoundLiteralPredicate>(pred);
690 switch (func->source_type()->type_id()) {
693 case TypeId::kDecimal:
694 return TransformNumericStrict(name, literalPredicate, func);
695 case TypeId::kString:
696 return TruncateStringLiteralStrict(name, literalPredicate, func);
697 case TypeId::kBinary:
698 return TruncateByteArrayStrict(name, literalPredicate, func);
700 return NotSupported(
"{} is not a valid input type for truncate transform",
701 func->source_type()->ToString());
// Strict projection for temporal transforms (year/month/day/hour).
// Literal predicates go through TransformNumericStrict, and NOT IN predicates
// through TransformSet. In both cases the result is then passed to
// FixStrictTimeProjection, except for the day transform over a date source,
// which skips the fix.
// The checked_pointer_cast assumes the projected predicate is an
// UnboundPredicateImpl<BoundReference>.
705 static Result<std::unique_ptr<UnboundPredicate>> TemporalProjectStrict(
706 std::string_view name,
const std::shared_ptr<BoundPredicate>& pred,
707 const std::shared_ptr<TransformFunction>& func) {
709 if (pred->kind() == BoundPredicate::Kind::kUnary) {
711 }
else if (pred->kind() == BoundPredicate::Kind::kLiteral) {
712 const auto& literalPredicate =
713 internal::checked_pointer_cast<BoundLiteralPredicate>(pred);
714 ICEBERG_ASSIGN_OR_RAISE(
auto projected,
715 TransformNumericStrict(name, literalPredicate, func));
// day(date) is exempt from the negative-ordinal fix-up.
716 if (func->transform_type() != TransformType::kDay ||
717 func->source_type()->type_id() != TypeId::kDate) {
718 return FixStrictTimeProjection(
719 internal::checked_pointer_cast<UnboundPredicateImpl<BoundReference>>(
720 std::move(projected)));
723 }
else if (pred->kind() == BoundPredicate::Kind::kSet &&
724 pred->op() == Expression::Operation::kNotIn) {
725 const auto& setPredicate = internal::checked_pointer_cast<BoundSetPredicate>(pred);
726 ICEBERG_ASSIGN_OR_RAISE(
auto projected, TransformSet(name, setPredicate, func));
727 if (func->transform_type() != TransformType::kDay ||
728 func->source_type()->type_id() != TypeId::kDate) {
729 return FixStrictTimeProjection(
730 internal::checked_pointer_cast<UnboundPredicateImpl<BoundReference>>(
731 std::move(projected)));