Allow varbinary to varchar coercion for hive tables
Praveen2112 committed Jul 3, 2024
1 parent cf5b073 commit e9f78fa
Showing 9 changed files with 331 additions and 7 deletions.
2 changes: 2 additions & 0 deletions docs/src/main/sphinx/connector/hive.md
@@ -631,6 +631,8 @@ type conversions.
- `VARCHAR`
* - `TIMESTAMP`
- `VARCHAR`, `DATE`
* - `VARBINARY`
- `VARCHAR`
:::

Any conversion failure results in null, which is the same behavior as Hive.
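
For illustration, here is a minimal standalone sketch of that null-on-failure contract using only JDK APIs. It is not the connector's code path (the connector's coercers are storage-format specific, as the new class further down shows), and the class and method names are made up:

import java.nio.ByteBuffer;
import java.nio.charset.CharacterCodingException;
import java.nio.charset.StandardCharsets;

// Hypothetical sketch: strictly decode VARBINARY bytes as UTF-8 and fall back
// to null on failure, mirroring the documented "conversion failure results in null".
public class NullOnFailureCoercionSketch
{
    static String coerceOrNull(byte[] varbinaryValue)
    {
        try {
            // newDecoder() reports malformed input by default instead of replacing it
            return StandardCharsets.UTF_8.newDecoder()
                    .decode(ByteBuffer.wrap(varbinaryValue))
                    .toString();
        }
        catch (CharacterCodingException e) {
            return null; // conversion failure -> null
        }
    }

    public static void main(String[] args)
    {
        System.out.println(coerceOrNull("abc".getBytes(StandardCharsets.UTF_8))); // abc
        System.out.println(coerceOrNull(new byte[] {(byte) 0xBF}));               // null
    }
}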
io/trino/plugin/hive/coercions/CoercionUtils.java
@@ -53,6 +53,7 @@
import io.trino.spi.type.TinyintType;
import io.trino.spi.type.Type;
import io.trino.spi.type.TypeManager;
import io.trino.spi.type.VarbinaryType;
import io.trino.spi.type.VarcharType;

import java.util.List;
@@ -79,6 +80,7 @@
import static io.trino.plugin.hive.coercions.DecimalCoercers.createRealToDecimalCoercer;
import static io.trino.plugin.hive.coercions.DoubleToVarcharCoercers.createDoubleToVarcharCoercer;
import static io.trino.plugin.hive.coercions.FloatToVarcharCoercers.createFloatToVarcharCoercer;
import static io.trino.plugin.hive.coercions.VarbinaryToVarcharCoercers.createVarbinaryToVarcharCoercer;
import static io.trino.plugin.hive.coercions.VarcharToIntegralNumericCoercers.createVarcharToIntegerNumberCoercer;
import static io.trino.spi.StandardErrorCode.NOT_SUPPORTED;
import static io.trino.spi.block.ColumnarArray.toColumnarArray;
@@ -273,6 +275,10 @@ public static Type createTypeFromCoercer(TypeManager typeManager, HiveType fromH
coercionContext);
}

if (fromType instanceof VarbinaryType && toType instanceof VarcharType varcharType) {
return Optional.of(createVarbinaryToVarcharCoercer(varcharType, coercionContext.storageFormat()));
}

throw new TrinoException(NOT_SUPPORTED, format("Unsupported coercion from %s to %s", fromHiveType, toHiveType));
}

io/trino/plugin/hive/coercions/VarbinaryToVarcharCoercers.java
@@ -0,0 +1,130 @@
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.trino.plugin.hive.coercions;

import io.airlift.slice.Slice;
import io.airlift.slice.Slices;
import io.trino.plugin.hive.HiveStorageFormat;
import io.trino.spi.block.Block;
import io.trino.spi.block.BlockBuilder;
import io.trino.spi.type.VarbinaryType;
import io.trino.spi.type.VarcharType;

import java.nio.charset.CharacterCodingException;
import java.nio.charset.CharsetDecoder;
import java.util.HexFormat;

import static io.trino.plugin.hive.HiveStorageFormat.ORC;
import static io.trino.plugin.hive.HiveStorageFormat.PARQUET;
import static io.trino.spi.type.VarbinaryType.VARBINARY;
import static io.trino.spi.type.Varchars.truncateToLength;
import static java.nio.charset.CodingErrorAction.REPLACE;
import static java.nio.charset.StandardCharsets.UTF_8;

public class VarbinaryToVarcharCoercers
{
private VarbinaryToVarcharCoercers() {}

public static TypeCoercer<VarbinaryType, VarcharType> createVarbinaryToVarcharCoercer(VarcharType toType, HiveStorageFormat storageFormat)
{
if (storageFormat == ORC) {
return new OrcVarbinaryToVarcharCoercer(toType);
}
if (storageFormat == PARQUET) {
return new ParquetVarbinaryToVarcharCoercer(toType);
}
return new VarbinaryToVarcharCoercer(toType);
}

private static class VarbinaryToVarcharCoercer
extends TypeCoercer<VarbinaryType, VarcharType>
{
public VarbinaryToVarcharCoercer(VarcharType toType)
{
super(VARBINARY, toType);
}

@Override
protected void applyCoercedValue(BlockBuilder blockBuilder, Block block, int position)
{
try {
Slice decodedValue = fromType.getSlice(block, position);
if (toType.isUnbounded()) {
toType.writeSlice(blockBuilder, decodedValue);
return;
}
toType.writeSlice(blockBuilder, truncateToLength(decodedValue, toType.getBoundedLength()));
}
catch (RuntimeException e) {
blockBuilder.appendNull();
}
}
}

private static class ParquetVarbinaryToVarcharCoercer
extends TypeCoercer<VarbinaryType, VarcharType>
{
public ParquetVarbinaryToVarcharCoercer(VarcharType toType)
{
super(VARBINARY, toType);
}

@Override
protected void applyCoercedValue(BlockBuilder blockBuilder, Block block, int position)
{
// Hive's coercion logic for Varbinary to Varchar
// https://github.com/apache/hive/blob/8190d2be7b7165effa62bd21b7d60ef81fb0e4af/serde/src/java/org/apache/hadoop/hive/serde2/objectinspector/primitive/PrimitiveObjectInspectorUtils.java#L911
// It uses Hadoop's Text#decode, which replaces malformed input with the Unicode replacement character (U+FFFD)
// https://github.com/apache/hadoop/blob/706d88266abcee09ed78fbaa0ad5f74d818ab0e9/hadoop-common-project/hadoop-common/src/main/java/org/apache/hadoop/io/Text.java#L414
CharsetDecoder decoder = UTF_8.newDecoder()
.onMalformedInput(REPLACE)
.onUnmappableCharacter(REPLACE);

try {
Slice decodedValue = Slices.utf8Slice(decoder.decode(fromType.getSlice(block, position).toByteBuffer()).toString());
if (toType.isUnbounded()) {
toType.writeSlice(blockBuilder, decodedValue);
return;
}
toType.writeSlice(blockBuilder, truncateToLength(decodedValue, toType.getBoundedLength()));
}
catch (CharacterCodingException e) {
blockBuilder.appendNull();
}
}
}

private static class OrcVarbinaryToVarcharCoercer
extends TypeCoercer<VarbinaryType, VarcharType>
{
private static final HexFormat HEX_FORMAT = HexFormat.of().withDelimiter(" ");

public OrcVarbinaryToVarcharCoercer(VarcharType toType)
{
super(VARBINARY, toType);
}

@Override
protected void applyCoercedValue(BlockBuilder blockBuilder, Block block, int position)
{
Slice value = fromType.getSlice(block, position);
Slice hexValue = Slices.utf8Slice(HEX_FORMAT.formatHex(value.byteArray(), value.byteArrayOffset(), value.length()));
if (toType.isUnbounded()) {
toType.writeSlice(blockBuilder, hexValue);
return;
}
toType.writeSlice(blockBuilder, truncateToLength(hexValue, toType.getBoundedLength()));
}
}
}
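
To make the format-specific behavior above concrete, here is a standalone sketch using only JDK APIs; it is not connector code and the class name is made up. The Parquet branch's lenient decoder turns malformed UTF-8 into U+FFFD, while the ORC branch renders the raw bytes as space-delimited lowercase hex:

import java.nio.ByteBuffer;
import java.nio.charset.CharsetDecoder;
import java.util.HexFormat;

import static java.nio.charset.CodingErrorAction.REPLACE;
import static java.nio.charset.StandardCharsets.UTF_8;

public class FormatSpecificCoercionSketch
{
    public static void main(String[] args) throws Exception
    {
        // 'X' followed by a stray UTF-8 continuation byte (0xBF)
        byte[] bytes = {(byte) 'X', (byte) 0xBF};

        // Parquet-style: lenient decode, malformed input becomes U+FFFD
        CharsetDecoder decoder = UTF_8.newDecoder()
                .onMalformedInput(REPLACE)
                .onUnmappableCharacter(REPLACE);
        System.out.println(decoder.decode(ByteBuffer.wrap(bytes))); // X\uFFFD

        // ORC-style: space-delimited lowercase hex of the raw bytes
        System.out.println(HexFormat.of().withDelimiter(" ").formatHex(bytes)); // 58 bf
    }
}

These outputs line up with the expectations in the new test further down ("X\uFFFD" for Parquet and "58 bf" for ORC on the same byte sequence).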
OrcTypeTranslator.java
@@ -59,6 +59,7 @@

import static com.google.common.collect.ImmutableList.toImmutableList;
import static com.google.common.collect.Iterables.getOnlyElement;
import static io.trino.orc.metadata.OrcType.OrcTypeKind.BINARY;
import static io.trino.orc.metadata.OrcType.OrcTypeKind.BOOLEAN;
import static io.trino.orc.metadata.OrcType.OrcTypeKind.BYTE;
import static io.trino.orc.metadata.OrcType.OrcTypeKind.DATE;
@@ -74,6 +75,7 @@
import static io.trino.orc.metadata.OrcType.OrcTypeKind.STRUCT;
import static io.trino.orc.metadata.OrcType.OrcTypeKind.TIMESTAMP;
import static io.trino.orc.metadata.OrcType.OrcTypeKind.VARCHAR;
import static io.trino.plugin.hive.HiveStorageFormat.ORC;
import static io.trino.plugin.hive.coercions.DecimalCoercers.createDecimalToDecimalCoercer;
import static io.trino.plugin.hive.coercions.DecimalCoercers.createDecimalToDoubleCoercer;
import static io.trino.plugin.hive.coercions.DecimalCoercers.createDecimalToInteger;
@@ -84,6 +86,7 @@
import static io.trino.plugin.hive.coercions.DecimalCoercers.createRealToDecimalCoercer;
import static io.trino.plugin.hive.coercions.DoubleToVarcharCoercers.createDoubleToVarcharCoercer;
import static io.trino.plugin.hive.coercions.FloatToVarcharCoercers.createFloatToVarcharCoercer;
import static io.trino.plugin.hive.coercions.VarbinaryToVarcharCoercers.createVarbinaryToVarcharCoercer;
import static io.trino.spi.type.BigintType.BIGINT;
import static io.trino.spi.type.IntegerType.INTEGER;
import static io.trino.spi.type.SmallintType.SMALLINT;
@@ -284,6 +287,10 @@ private OrcTypeTranslator() {}
}
return Optional.empty();
}

if (fromOrcTypeKind == BINARY && toTrinoType instanceof VarcharType varcharType) {
return Optional.of(createVarbinaryToVarcharCoercer(varcharType, ORC));
}
return Optional.empty();
}

HiveCoercionPolicy.java
@@ -23,6 +23,7 @@
import io.trino.spi.type.DecimalType;
import io.trino.spi.type.Type;
import io.trino.spi.type.TypeManager;
import io.trino.spi.type.VarbinaryType;
import io.trino.spi.type.VarcharType;

import java.util.List;
@@ -87,7 +88,8 @@ private boolean canCoerce(HiveType fromHiveType, HiveType toHiveType, HiveTimest
fromHiveType.equals(HIVE_FLOAT) ||
fromHiveType.equals(HIVE_DOUBLE) ||
fromHiveType.equals(HIVE_DATE) ||
fromType instanceof DecimalType;
fromType instanceof DecimalType ||
fromType instanceof VarbinaryType;
}
if (toHiveType.equals(HIVE_DATE)) {
return fromHiveType.equals(HIVE_TIMESTAMP);
io/trino/plugin/hive/coercions/TestVarbinaryToVarcharCoercer.java
@@ -0,0 +1,123 @@
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.trino.plugin.hive.coercions;

import io.airlift.slice.Slice;
import io.airlift.slice.Slices;
import io.trino.plugin.hive.HiveStorageFormat;
import io.trino.plugin.hive.coercions.CoercionUtils.CoercionContext;
import io.trino.spi.block.Block;
import io.trino.spi.type.Type;
import org.junit.jupiter.api.Test;

import static io.trino.plugin.hive.HiveStorageFormat.ORC;
import static io.trino.plugin.hive.HiveStorageFormat.PARQUET;
import static io.trino.plugin.hive.HiveStorageFormat.RCTEXT;
import static io.trino.plugin.hive.HiveTimestampPrecision.DEFAULT_PRECISION;
import static io.trino.plugin.hive.HiveType.toHiveType;
import static io.trino.plugin.hive.coercions.CoercionUtils.createCoercer;
import static io.trino.spi.predicate.Utils.blockToNativeValue;
import static io.trino.spi.predicate.Utils.nativeValueToBlock;
import static io.trino.spi.type.VarbinaryType.VARBINARY;
import static io.trino.spi.type.VarcharType.VARCHAR;
import static io.trino.spi.type.VarcharType.createVarcharType;
import static io.trino.type.InternalTypeManager.TESTING_TYPE_MANAGER;
import static org.assertj.core.api.Assertions.assertThat;

public class TestVarbinaryToVarcharCoercer
{
private static final byte CONTINUATION_BYTE = (byte) 0b1011_1111;
private static final byte START_4_BYTE = (byte) 0b1111_0111;
private static final byte X_CHAR = (byte) 'X';

// Test cases are copied from https://github.com/airlift/slice/blob/master/src/test/java/io/airlift/slice/TestSliceUtf8.java

@Test
public void testVarbinaryToVarcharCoercion()
{
assertVarbinaryToVarcharCoercion(Slices.utf8Slice("abc"), VARBINARY, Slices.utf8Slice("abc"), VARCHAR);
assertVarbinaryToVarcharCoercion(Slices.utf8Slice("abc"), VARBINARY, Slices.utf8Slice("ab"), createVarcharType(2));
// Invalid UTF-8 encoding
assertVarbinaryToVarcharCoercion(Slices.wrappedBuffer(X_CHAR, CONTINUATION_BYTE), VARBINARY, Slices.wrappedBuffer(X_CHAR, CONTINUATION_BYTE), VARCHAR);
assertVarbinaryToVarcharCoercion(
Slices.wrappedBuffer(X_CHAR, START_4_BYTE, CONTINUATION_BYTE, CONTINUATION_BYTE, CONTINUATION_BYTE),
VARBINARY,
Slices.wrappedBuffer(X_CHAR, START_4_BYTE, CONTINUATION_BYTE, CONTINUATION_BYTE, CONTINUATION_BYTE),
VARCHAR);
assertVarbinaryToVarcharCoercion(
Slices.wrappedBuffer(X_CHAR, START_4_BYTE, CONTINUATION_BYTE, CONTINUATION_BYTE, CONTINUATION_BYTE, X_CHAR),
VARBINARY,
Slices.wrappedBuffer(X_CHAR, START_4_BYTE, CONTINUATION_BYTE, CONTINUATION_BYTE, CONTINUATION_BYTE, X_CHAR),
VARCHAR);
assertVarbinaryToVarcharCoercion(
Slices.wrappedBuffer(X_CHAR, (byte) 0b11101101, (byte) 0xA0, (byte) 0x80),
VARBINARY,
Slices.wrappedBuffer(X_CHAR, (byte) 0b11101101, (byte) 0xA0, (byte) 0x80),
VARCHAR);
assertVarbinaryToVarcharCoercion(
Slices.wrappedBuffer(X_CHAR, (byte) 0b11101101, (byte) 0xBF, (byte) 0xBF),
VARBINARY,
Slices.wrappedBuffer(X_CHAR, (byte) 0b11101101, (byte) 0xBF, (byte) 0xBF),
VARCHAR);
}

@Test
public void testVarbinaryToVarcharCoercionForParquet()
{
assertVarbinaryToVarcharCoercionForParquet(Slices.utf8Slice("abc"), VARBINARY, "abc", VARCHAR);
assertVarbinaryToVarcharCoercionForParquet(Slices.utf8Slice("abc"), VARBINARY, "ab", createVarcharType(2));
// Invalid UTF-8 encoding
assertVarbinaryToVarcharCoercionForParquet(Slices.wrappedBuffer(X_CHAR, CONTINUATION_BYTE), VARBINARY, "X\uFFFD", VARCHAR);
assertVarbinaryToVarcharCoercionForParquet(Slices.wrappedBuffer(X_CHAR, START_4_BYTE, CONTINUATION_BYTE, CONTINUATION_BYTE, CONTINUATION_BYTE), VARBINARY, "X\uFFFD\uFFFD\uFFFD\uFFFD", VARCHAR);
assertVarbinaryToVarcharCoercionForParquet(Slices.wrappedBuffer(X_CHAR, START_4_BYTE, CONTINUATION_BYTE, CONTINUATION_BYTE, CONTINUATION_BYTE, X_CHAR), VARBINARY, "X\uFFFD\uFFFD\uFFFD\uFFFDX", VARCHAR);
assertVarbinaryToVarcharCoercionForParquet(Slices.wrappedBuffer(X_CHAR, (byte) 0b11101101, (byte) 0xA0, (byte) 0x80), VARBINARY, "X\uFFFD", VARCHAR);
assertVarbinaryToVarcharCoercionForParquet(Slices.wrappedBuffer(X_CHAR, (byte) 0b11101101, (byte) 0xBF, (byte) 0xBF), VARBINARY, "X\uFFFD", VARCHAR);
}

@Test
public void testVarbinaryToVarcharCoercionForOrc()
{
assertVarbinaryToVarcharCoercionForOrc(Slices.utf8Slice("abc"), VARBINARY, "61 62 63", VARCHAR);
assertVarbinaryToVarcharCoercionForOrc(Slices.utf8Slice("abc"), VARBINARY, "61", createVarcharType(2));
// Invalid UTF-8 encoding
assertVarbinaryToVarcharCoercionForOrc(Slices.wrappedBuffer(X_CHAR, CONTINUATION_BYTE), VARBINARY, "58 bf", VARCHAR);
assertVarbinaryToVarcharCoercionForOrc(Slices.wrappedBuffer(X_CHAR, START_4_BYTE, CONTINUATION_BYTE, CONTINUATION_BYTE, CONTINUATION_BYTE), VARBINARY, "58 f7 bf bf bf", VARCHAR);
assertVarbinaryToVarcharCoercionForOrc(Slices.wrappedBuffer(X_CHAR, START_4_BYTE, CONTINUATION_BYTE, CONTINUATION_BYTE, CONTINUATION_BYTE, X_CHAR), VARBINARY, "58 f7 bf bf bf 58", VARCHAR);
assertVarbinaryToVarcharCoercionForOrc(Slices.wrappedBuffer(X_CHAR, (byte) 0b11101101, (byte) 0xA0, (byte) 0x80), VARBINARY, "58 ed a0 80", VARCHAR);
assertVarbinaryToVarcharCoercionForOrc(Slices.wrappedBuffer(X_CHAR, (byte) 0b11101101, (byte) 0xBF, (byte) 0xBF), VARBINARY, "58 ed bf bf", VARCHAR);
}

private static void assertVarbinaryToVarcharCoercion(Slice actualValue, Type fromType, Slice expectedValue, Type toType)
{
assertVarbinaryToVarcharCoercion(actualValue, fromType, expectedValue, toType, RCTEXT);
}

private static void assertVarbinaryToVarcharCoercionForOrc(Slice actualValue, Type fromType, String expectedValue, Type toType)
{
assertVarbinaryToVarcharCoercion(actualValue, fromType, Slices.utf8Slice(expectedValue), toType, ORC);
}

private static void assertVarbinaryToVarcharCoercionForParquet(Slice actualValue, Type fromType, String expectedValue, Type toType)
{
assertVarbinaryToVarcharCoercion(actualValue, fromType, Slices.utf8Slice(expectedValue), toType, PARQUET);
}

private static void assertVarbinaryToVarcharCoercion(Slice actualValue, Type fromType, Slice expectedValue, Type toType, HiveStorageFormat storageFormat)
{
Block coercedBlock = createCoercer(TESTING_TYPE_MANAGER, toHiveType(fromType), toHiveType(toType), new CoercionContext(DEFAULT_PRECISION, storageFormat)).orElseThrow()
.apply(nativeValueToBlock(fromType, actualValue));
assertThat(blockToNativeValue(toType, coercedBlock))
.isEqualTo(expectedValue);
}
}