@@ -20,9 +20,13 @@ package org.apache.spark.sql.catalyst.analysis
20
20
import org .apache .spark .sql .catalyst .expressions .{Cast , DefaultStringProducingExpression , Expression , Literal , SubqueryExpression }
21
21
import org .apache .spark .sql .catalyst .plans .logical .{AddColumns , AlterColumns , AlterColumnSpec , AlterViewAs , ColumnDefinition , CreateTable , CreateTempView , CreateView , LogicalPlan , QualifiedColType , ReplaceColumns , ReplaceTable , TableSpec , V2CreateTablePlan }
22
22
import org .apache .spark .sql .catalyst .rules .Rule
23
- import org .apache .spark .sql .connector .catalog .{SupportsNamespaces , TableCatalog }
23
+ import org .apache .spark .sql .catalyst .trees .CurrentOrigin
24
+ import org .apache .spark .sql .catalyst .util .CharVarcharUtils .CHAR_VARCHAR_TYPE_STRING_METADATA_KEY
25
+ import org .apache .spark .sql .connector .catalog .{CatalogV2Util , SupportsNamespaces , Table , TableCatalog }
24
26
import org .apache .spark .sql .connector .catalog .SupportsNamespaces .PROP_COLLATION
25
- import org .apache .spark .sql .types .{DataType , StringType }
27
+ import org .apache .spark .sql .errors .DataTypeErrors .toSQLId
28
+ import org .apache .spark .sql .errors .QueryCompilationErrors
29
+ import org .apache .spark .sql .types .{CharType , DataType , StringType , StructField , VarcharType }
26
30
27
31
/**
28
32
* Resolves string types in logical plans by assigning them the appropriate collation. The
@@ -33,12 +37,13 @@ import org.apache.spark.sql.types.{DataType, StringType}
33
37
*/
34
38
object ApplyDefaultCollationToStringType extends Rule [LogicalPlan ] {
35
39
def apply (plan : LogicalPlan ): LogicalPlan = {
36
- val planWithResolvedDefaultCollation = resolveDefaultCollation(plan)
40
+ val preprocessedPlan = Seq (resolveDefaultCollation _, resolveAlterColumnsDataType _)
41
+ .foldLeft(plan) { case (currentPlan, resolver) => resolver(currentPlan) }
37
42
38
- fetchDefaultCollation(planWithResolvedDefaultCollation ) match {
43
+ fetchDefaultCollation(preprocessedPlan ) match {
39
44
case Some (collation) =>
40
- transform(planWithResolvedDefaultCollation , StringType (collation))
41
- case None => planWithResolvedDefaultCollation
45
+ transform(preprocessedPlan , StringType (collation))
46
+ case None => preprocessedPlan
42
47
}
43
48
}
44
49
@@ -63,10 +68,14 @@ object ApplyDefaultCollationToStringType extends Rule[LogicalPlan] {
63
68
case ReplaceTable (_ : ResolvedIdentifier , _, _, tableSpec : TableSpec , _) =>
64
69
tableSpec.collation
65
70
66
- // In `transform` we handle these 3 ALTER TABLE commands.
67
- case cmd : AddColumns => getCollationFromTableProps(cmd.table)
68
- case cmd : ReplaceColumns => getCollationFromTableProps(cmd.table)
69
- case cmd : AlterColumns => getCollationFromTableProps(cmd.table)
71
+ case AddColumns (resolvedTable : ResolvedTable , _) =>
72
+ Option (resolvedTable.table.properties.get(TableCatalog .PROP_COLLATION ))
73
+
74
+ case ReplaceColumns (resolvedTable : ResolvedTable , _) =>
75
+ Option (resolvedTable.table.properties.get(TableCatalog .PROP_COLLATION ))
76
+
77
+ case AlterColumns (resolvedTable : ResolvedTable , _) =>
78
+ Option (resolvedTable.table.properties.get(TableCatalog .PROP_COLLATION ))
70
79
71
80
case alterViewAs : AlterViewAs =>
72
81
alterViewAs.child match {
@@ -85,15 +94,6 @@ object ApplyDefaultCollationToStringType extends Rule[LogicalPlan] {
85
94
}
86
95
}
87
96
88
- private def getCollationFromTableProps (t : LogicalPlan ): Option [String ] = {
89
- t match {
90
- case resolvedTbl : ResolvedTable
91
- if resolvedTbl.table.properties.containsKey(TableCatalog .PROP_COLLATION ) =>
92
- Some (resolvedTbl.table.properties.get(TableCatalog .PROP_COLLATION ))
93
- case _ => None
94
- }
95
- }
96
-
97
97
/**
98
98
* Determines the default collation for an object in the following order:
99
99
* 1. Use the object's explicitly defined default collation, if available.
@@ -168,22 +168,86 @@ object ApplyDefaultCollationToStringType extends Rule[LogicalPlan] {
168
168
case p if isCreateOrAlterPlan(p) || AnalysisContext .get.collation.isDefined =>
169
169
transformPlan(p, newType)
170
170
171
- case addCols : AddColumns =>
171
+ case addCols@ AddColumns ( _ : ResolvedTable , _) =>
172
172
addCols.copy(column
341A
sToAdd = replaceColumnTypes(addCols.columnsToAdd, newType))
173
173
174
- case replaceCols : ReplaceColumns =>
174
+ case replaceCols@ ReplaceColumns ( _ : ResolvedTable , _) =>
175
175
replaceCols.copy(columnsToAdd = replaceColumnTypes(replaceCols.columnsToAdd, newType))
176
176
177
- case a @ AlterColumns (_ , specs : Seq [AlterColumnSpec ]) =>
177
+ case a @ AlterColumns (ResolvedTable (_, _, table : Table , _) , specs : Seq [AlterColumnSpec ]) =>
178
178
val newSpecs = specs.map {
179
- case spec if spec.newDataType.isDefined && hasDefaultStringType (spec.newDataType.get ) =>
179
+ case spec if shouldApplyDefaultCollationToAlterColumn (spec, table ) =>
180
180
spec.copy(newDataType = Some (replaceDefaultStringType(spec.newDataType.get, newType)))
181
181
case col => col
182
182
}
183
183
a.copy(specs = newSpecs)
184
184
}
185
185
}
186
186
187
+ /**
188
+ * The column type should not be changed if the original column type is [[StringType ]] and the new
189
+ * type is the default [[StringType ]] (i.e., [[StringType ]] without an explicit collation).
190
+ *
191
+ * Query Example:
192
+ * {{{
193
+ * CREATE TABLE t (c1 STRING COLLATE UNICODE)
194
+ * ALTER TABLE t ALTER COLUMN c1 TYPE STRING -- c1 will remain STRING COLLATE UNICODE
195
+ * }}}
196
+ */
197
+ private def resolveAlterColumnsDataType (plan : LogicalPlan ): LogicalPlan = {
198
+ plan match {
199
+ case alterColumns@ AlterColumns (
200
+ ResolvedTable (_, _, table : Table , _), specs : Seq [AlterColumnSpec ]) =>
201
+ val resolvedSpecs = specs.map { spec =>
202
+ if (spec.newDataType.isDefined && isStringTypeColumn(spec.column, table) &&
203
+ isDefaultStringType(spec.newDataType.get)) {
204
+ spec.copy(newDataType = None )
205
+ } else {
206
+ spec
207
+ }
208
+ }
209
+ val newAlterColumns = CurrentOrigin .withOrigin(alterColumns.origin) {
210
+ alterColumns.copy(specs = resolvedSpecs)
211
+ }
212
+ newAlterColumns.copyTagsFrom(alterColumns)
213
+ newAlterColumns
214
+ case _ =>
215
+ plan
216
+ }
217
+ }
218
+
219
+ private def shouldApplyDefaultCollationToAlterColumn (
220
+ alterColumnSpec : AlterColumnSpec , table : Table ): Boolean = {
221
+ alterColumnSpec.newDataType.isDefined &&
222
+ // Applies the default collation only if the original column's type is not StringType.
223
+ ! isStringTypeColumn(alterColumnSpec.column, table) &&
224
+ hasDefaultStringType(alterColumnSpec.newDataType.get)
225
+ }
226
+
227
+ /**
228
+ * Checks whether the column's [[DataType ]] is [[StringType ]] in the given table. Throws an error
229
+ * if the column is not found.
230
+ */
231
+ private def isStringTypeColumn (fieldName : FieldName , table : Table ): Boolean = {
232
+ CatalogV2Util .v2ColumnsToStructType(table.columns())
233
+ .findNestedField(fieldName.name, includeCollections = true , resolver = conf.resolver)
234
+ .map {
235
+ case (_, StructField (_, _ : CharType , _, _)) =>
236
+ false
237
+ case (_, StructField (_, _ : VarcharType , _, _)) =>
238
+ false
239
+ case (_, StructField (_, _ : StringType , _, metadata))
240
+ if ! metadata.contains(CHAR_VARCHAR_TYPE_STRING_METADATA_KEY ) =>
241
+ true
242
+ case (_, _) =>
243
+ false
244
+ }
245
+ .getOrElse {
246
+ throw QueryCompilationErrors .unresolvedColumnError(
247
+ toSQLId(fieldName.name), table.columns().map(_.name))
248
+ }
249
+ }
250
+
187
251
/**
188
252
* Transforms the given plan, by transforming all expressions in its operators to use the given
189
253
* new type instead of the default string type.
0 commit comments