From aea90529d45a6ce2091807e66a3bf62b7dee43c3 Mon Sep 17 00:00:00 2001
From: Jesse Whitehouse
Date: Wed, 11 Oct 2023 13:36:41 -0400
Subject: [PATCH] Add type inference for BIGINT and TINYINT types

Signed-off-by: Jesse Whitehouse
---
Reviewer notes (this section is not part of the commit message):
- resolve_databricks_sql_integer_type() only ever returns TINYINT, INTEGER
  or BIGINT; SMALLINT is never inferred.
- Integers outside the signed 64-bit range fall through to the BIGINT
  branch; no range check is performed.
- Line breaks, blank context lines and indentation were reconstructed from
  the hunk line counts; verify against the target tree before applying.

 src/databricks/sql/utils.py   | 14 ++++++++
 tests/unit/test_parameters.py | 63 +++++++++++++++++++++++------------
 2 files changed, 56 insertions(+), 21 deletions(-)

diff --git a/src/databricks/sql/utils.py b/src/databricks/sql/utils.py
index ae5160ef1..5265380fd 100644
--- a/src/databricks/sql/utils.py
+++ b/src/databricks/sql/utils.py
@@ -533,6 +533,16 @@ def named_parameters_to_dbsqlparams_v2(parameters: List[Any]):
     return dbsqlparams
 
 
+def resolve_databricks_sql_integer_type(integer):
+    """Returns the smallest Databricks SQL integer type that can contain the passed integer"""
+    if -128 <= integer <= 127:
+        return DbSqlType.TINYINT
+    elif -2147483648 <= integer <= 2147483647:
+        return DbSqlType.INTEGER
+    else:
+        return DbSqlType.BIGINT
+
+
 def infer_types(params: list[DbSqlParameter]):
     type_lookup_table = {
         str: DbSqlType.STRING,
@@ -568,6 +578,10 @@ def infer_types(params: list[DbSqlParameter]):
             cast_exp = calculate_decimal_cast_string(param.value)
             _type = DbsqlDynamicDecimalType(cast_exp)
 
+        # int() requires special handling because one Python type can be cast to multiple SQL types (INT, BIGINT, TINYINT)
+        if _type == DbSqlType.INTEGER:
+            _type = resolve_databricks_sql_integer_type(param.value)
+
         # VOID / NULL types must be passed in a unique way as TSparkParameters with no value
         if _type == DbSqlType.VOID:
             new_params.append(DbSqlParameter(name=_name, type=DbSqlType.VOID))
diff --git a/tests/unit/test_parameters.py b/tests/unit/test_parameters.py
index b131ea7cc..1370def35 100644
--- a/tests/unit/test_parameters.py
+++ b/tests/unit/test_parameters.py
@@ -19,28 +19,36 @@
 
 
 class TestTSparkParameterConversion(object):
-    def test_conversion_e2e(self):
+    @pytest.mark.parametrize(
+        "input_value, expected_type",
+        [
+            ("a", "STRING"),
+            (1, "TINYINT"),
+            (1000, "INTEGER"),
+            (9223372036854775807, "BIGINT"),  # Max value of a signed 64-bit integer
+            (True, "BOOLEAN"),
+            (1.0, "FLOAT"),
+        ],
+    )
+    def test_conversion_e2e(self, input_value, expected_type):
         """This behaviour falls back to Python's default string formatting of numbers"""
-        assert named_parameters_to_tsparkparams(
-            ["a", 1, True, 1.0, DbSqlParameter(value="1.0", type=DbSqlType.DECIMAL)]
-        ) == [
-            TSparkParameter(
-                name="", type="STRING", value=TSparkParameterValue(stringValue="a")
-            ),
-            TSparkParameter(
-                name="", type="INTEGER", value=TSparkParameterValue(stringValue="1")
-            ),
-            TSparkParameter(
-                name="", type="BOOLEAN", value=TSparkParameterValue(stringValue="True")
-            ),
-            TSparkParameter(
-                name="", type="FLOAT", value=TSparkParameterValue(stringValue="1.0")
-            ),
+        output = named_parameters_to_tsparkparams([input_value])
+        expected = TSparkParameter(
+            name="",
+            type=expected_type,
+            value=TSparkParameterValue(stringValue=str(input_value)),
+        )
+        assert output == [expected]
+
+    def test_conversion_e2e_decimal(self):
+        input = DbSqlParameter(value="1.0", type=DbSqlType.DECIMAL)
+        output = named_parameters_to_tsparkparams([input])
+        assert output == [
             TSparkParameter(
                 name="",
                 type="DECIMAL(2,1)",
                 value=TSparkParameterValue(stringValue="1.0"),
-            ),
+            )
         ]
 
     def test_basic_conversions_v1(self):
@@ -69,10 +77,24 @@ def test_infer_types_dict(self):
         with pytest.raises(ValueError):
             infer_types([DbSqlParameter("", {1: 1})])
 
-    def test_infer_types_integer(self):
-        input = DbSqlParameter("", 1)
+    @pytest.mark.parametrize(
+        "input_value, expected_type",
+        [
+            (-128, DbSqlType.TINYINT),
+            (127, DbSqlType.TINYINT),
+            (-2147483649, DbSqlType.BIGINT),
+            (-2147483648, DbSqlType.INTEGER),
+            (2147483647, DbSqlType.INTEGER),
+            (-9223372036854775808, DbSqlType.BIGINT),
+            (9223372036854775807, DbSqlType.BIGINT),
+        ],
+    )
+    def test_infer_types_integer(self, input_value, expected_type):
+        input = DbSqlParameter("", input_value)
         output = infer_types([input])
-        assert output == [DbSqlParameter("", "1", DbSqlType.INTEGER)]
+        assert output == [
+            DbSqlParameter("", str(input_value), expected_type)
+        ], f"{output[0].type} received, expected {expected_type}"
 
     def test_infer_types_boolean(self):
         input = DbSqlParameter("", True)
@@ -101,7 +123,6 @@ def test_infer_types_decimal(self):
         assert x.type.value == "DECIMAL(2,1)"
 
     def test_infer_types_none(self):
-
         input = DbSqlParameter("", None)
         output: List[DbSqlParameter] = infer_types([input])
 