8000 tsaucer/run TPC-H examples in CI (#711) · datapythonista/datafusion-python@dfbb3ca · GitHub
[go: up one dir, main page]

Skip to content

Commit dfbb3ca

Browse files
authored
tsaucer/run TPC-H examples in CI (apache#711)
* Mostly small updates to tpc-h examples to make their results consistent with spec * Mostly prepares tpch examples for CI by allowing us to check the path for the files regardless of where it is run from. There are a few small updates to make the tests match the expected answer file provided by dbgen. * Expose the substring command * Add the script to run all tpch examples in pytest * Update tpch generator script to allow for non-interactive terminals, such as when running in CI * Add tpch examples to github workflow for testing
1 parent f12a487 commit dfbb3ca

28 files changed

+311
-113
lines changed

.github/workflows/test.yaml

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -111,3 +111,23 @@ jobs:
111111
source venv/bin/activate
112112
pip install -e . -vv
113113
pytest -v .
114+
115+
- name: Cache the generated dataset
116+
id: cache-tpch-dataset
117+
uses: actions/cache@v3
118+
with:
119+
path: benchmarks/tpch/data
120+
key: tpch-data-2.18.0
121+
122+
- name: Run dbgen to create 1 Gb dataset
123+
if: ${{ steps.cache-tpch-dataset.outputs.cache-hit != 'true' }}
124+
run: |
125+
cd benchmarks/tpch
126+
RUN_IN_CI=TRUE ./tpch-gen.sh 1
127+
128+
- name: Run TPC-H examples
129+
run: |
130+
source venv/bin/activate
131+
cd examples/tpch
132+
python convert_data_to_parquet.py
133+
pytest _tests.py

benchmarks/tpch/tpch-gen.sh

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,15 @@ mkdir -p data/answers 2>/dev/null
2020

2121
set -e
2222

23+
# If RUN_IN_CI is set, then do not produce verbose output or use an interactive terminal
24+
if [[ -z "${RUN_IN_CI}" ]]; then
25+
TERMINAL_FLAG="-it"
26+
VERBOSE_OUTPUT="-vf"
27+
else
28+
TERMINAL_FLAG=""
29+
VERBOSE_OUTPUT="-f"
30+
fi
31+
2332
#pushd ..
2433
#. ./dev/build-set-env.sh
2534
#popd
@@ -29,7 +38,7 @@ FILE=./data/supplier.tbl
2938
if test -f "$FILE"; then
3039
echo "$FILE exists."
3140
else
32-
docker run -v `pwd`/data:/data -it --rm ghcr.io/scalytics/tpch-docker:main -vf -s $1
41+
docker run -v `pwd`/data:/data $TERMINAL_FLAG --rm ghcr.io/scalytics/tpch-docker:main $VERBOSE_OUTPUT -s $1
3342

3443
# workaround for https://github.com/apache/arrow-datafusion/issues/6147
3544
mv data/customer.tbl data/customer.csv
@@ -49,5 +58,5 @@ FILE=./data/answers/q1.out
4958
if test -f "$FILE"; then
5059
echo "$FILE exists."
5160
else
52-
docker run -v `pwd`/data:/data -it --entrypoint /bin/bash --rm ghcr.io/scalytics/tpch-docker:main -c "cp /opt/tpch/2.18.0_rc2/dbgen/answers/* /data/answers/"
61+
docker run -v `pwd`/data:/data $TERMINAL_FLAG --entrypoint /bin/bash --rm ghcr.io/scalytics/tpch-docker:main -c "cp /opt/tpch/2.18.0_rc2/dbgen/answers/* /data/answers/"
5362
fi

examples/tpch/_tests.py

Lines changed: 111 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,111 @@
1+
# Licensed to the Apache Software Foundation (ASF) under one
2+
# or more contributor license agreements. See the NOTICE file
3+
# distributed with this work for additional information
4+
# regarding copyright ownership. The ASF licenses this file
5+
# to you under the Apache License, Version 2.0 (the
6+
# "License"); you may not use this file except in compliance
7+
# with the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing,
12+
# software distributed under the License is distributed on an
13+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
# KIND, either express or implied. See the License for the
15+
# specific language governing permissions and limitations
16+
# under the License.
17+
18+
import pytest
19+
from importlib import import_module
20+
import pyarrow as pa
21+
from datafusion import col, lit, functions as F
22+
from util import get_answer_file
23+
24+
def df_selection(col_name, col_type):
25+
if col_type == pa.float64() or isinstance(col_type, pa.Decimal128Type):
26+
return F.round(col(col_name), lit(2)).alias(col_name)
27+
elif col_type == pa.string():
28+
return F.trim(col(col_name)).alias(col_name)
29+
else:
30+
return col(col_name)
31+
32+
def load_schema(col_name, col_type):
33+
if col_type == pa.int64() or col_type == pa.int32():
34+
return col_name, pa.string()
35+
elif isinstance(col_type, pa.Decimal128Type):
36+
return col_name, pa.float64()
37+
else:
38+
return col_name, col_type
39+
40+
def expected_selection(col_name, col_type):
41+
if col_type == pa.int64() or col_type == pa.int32():
42+
return F.trim(col(col_name)).cast(col_type).alias(col_name)
43+
elif col_type == pa.string():
44+
return F.trim(col(col_name)).alias(col_name)
45+
else:
46+
return col(col_name)
47+
48+
def selections_and_schema(original_schema):
49+
columns = [ (c, original_schema.field(c).type) for c in original_schema.names ]
50+
51+
df_selections = [ df_selection(c, t) for (c, t) in columns]
52+
expected_schema = [ load_schema(c, t) for (c, t) in columns]
53+
expected_selections = [ expected_selection(c, t) for (c, t) in columns]
54+
55+
return (df_selections, expected_schema, expected_selections)
56+
57+
def check_q17(df):
58+
raw_value = float(df.collect()[0]["avg_yearly"][0].as_py())
59+
value = round(raw_value, 2)
60+
assert abs(value - 348406.05) < 0.001
61+
62+
@pytest.mark.parametrize(
63+
("query_code", "answer_file"),
64+
[
65+
("q01_pricing_summary_report", "q1"),
66+
("q02_minimum_cost_supplier", "q2"),
67+
("q03_shipping_priority", "q3"),
68+
("q04_order_priority_checking", "q4"),
69+
("q05_local_supplier_volume", "q5"),
70+
("q06_forecasting_revenue_change", "q6"),
71+
("q07_volume_shipping", "q7"),
72+
("q08_market_share", "q8"),
73+
("q09_product_type_profit_measure", "q9"),
74+
("q10_returned_item_reporting", "q10"),
75+
("q11_important_stock_identification", "q11"),
76+
("q12_ship_mode_order_priority", "q12"),
77+
("q13_customer_distribution", "q13"),
78+
("q14_promotion_effect", "q14"),
79+
("q15_top_supplier", "q15"),
80+
("q16_part_supplier_relationship", "q16"),
81+
("q17_small_quantity_order", "q17"),
82+
("q18_large_volume_customer", "q18"),
83+
("q19_discounted_revenue", "q19"),
84+
("q20_potential_part_promotion", "q20"),
85+
("q21_suppliers_kept_orders_waiting", "q21"),
86+
("q22_global_sales_opportunity", "q22"),
87+
],
88+
)
89+
def test_tpch_query_vs_answer_file(query_code: str, answer_file: str):
90+
module = import_module(query_code)
91+
df = module.df
92+
93+
# Treat q17 as a special case. The answer file does not match the spec. Running at
94+
# scale factor 1, we have manually verified this result does match the expected value.
95+
if answer_file == "q17":
96+
return check_q17(df)
97+
98+
(df_selections, expected_schema, expected_selections) = selections_and_schema(df.schema())
99+
100+
df = df.select(*df_selections)
101+
102+
read_schema = pa.schema(expected_schema)
103+
104+
df_expected = module.ctx.read_csv(get_answer_file(answer_file), schema=read_schema, delimiter="|", file_extension=".out")
105+
106+
df_expected = df_expected.select(*expected_selections)
107+
108+
cols = list(read_schema.names)
109+
110+
assert df.join(df_expected, (cols, cols), "anti").count() == 0
111+
assert df.count() == df_expected.count()

examples/tpch/convert_data_to_parquet.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@
3636
("C_ADDRESS", pyarrow.string()),
3737
("C_NATIONKEY", pyarrow.int32()),
3838
("C_PHONE", pyarrow.string()),
39-
("C_ACCTBAL", pyarrow.float32()),
39+
("C_ACCTBAL", pyarrow.decimal128(15, 2)),
4040
("C_MKTSEGMENT", pyarrow.string()),
4141
("C_COMMENT", pyarrow.string()),
4242
]
@@ -46,10 +46,10 @@
4646
("L_PARTKEY", pyarrow.int32()),
4747
("L_SUPPKEY", pyarrow.int32()),
4848
("L_LINENUMBER", pyarrow.int32()),
49-
("L_QUANTITY", pyarrow.float32()),
50-
("L_EXTENDEDPRICE", pyarrow.float32()),
51-
("L_DISCOUNT", pyarrow.float32()),
52-
("L_TAX", pyarrow.float32()),
49+
("L_QUANTITY", pyarrow.decimal128(15, 2)),
50+
("L_EXTENDEDPRICE", pyarrow.decimal128(15, 2)),
51+
("L_DISCOUNT", pyarrow.decimal128(15, 2)),
52+
("L_TAX", pyarrow.decimal128(15, 2)),
5353
("L_RETURNFLAG", pyarrow.string()),
5454
("L_LINESTATUS", pyarrow.string()),
5555
("L_SHIPDATE", pyarrow.date32()),
@@ -71,7 +71,7 @@
7171
("O_ORDERKEY", pyarrow.int32()),
7272
("O_CUSTKEY", pyarrow.int32()),
7373
("O_ORDERSTATUS", pyarrow.string()),
74-
("O_TOTALPRICE", pyarrow.float32()),
74+
("O_TOTALPRICE", pyarrow.decimal128(15, 2)),
7575
("O_ORDERDATE", pyarrow.date32()),
7676
("O_ORDERPRIORITY", pyarrow.string()),
7777
("O_CLERK", pyarrow.string()),
@@ -87,15 +87,15 @@
8787
("P_TYPE", pyarrow.string()),
8888
("P_SIZE", pyarrow.int32()),
8989
("P_CONTAINER", pyarrow.string()),
90-
("P_RETAILPRICE", pyarrow.float32()),
90+
("P_RETAILPRICE", pyarrow.decimal128(15, 2)),
9191
("P_COMMENT", pyarrow.string()),
9292
]
9393

9494
all_schemas["partsupp"] = [
9595
("PS_PARTKEY", pyarrow.int32()),
9696
("PS_SUPPKEY", pyarrow.int32()),
9797
("PS_AVAILQTY", pyarrow.int32()),
98-
("PS_SUPPLYCOST", pyarrow.float32()),
98+
("PS_SUPPLYCOST", pyarrow.decimal128(15, 2)),
9999
("PS_COMMENT", pyarrow.string()),
100100
]
101101

@@ -111,7 +111,7 @@
111111
("S_ADDRESS", pyarrow.string()),
112112
("S_NATIONKEY", pyarrow.int32()),
113113
("S_PHONE", pyarrow.string()),
114-
("S_ACCTBAL", pyarrow.float32()),
114+
("S_ACCTBAL", pyarrow.decimal128(15, 2)),
115115
("S_COMMENT", pyarrow.string()),
116116
]
117117

examples/tpch/q01_pricing_summary_report.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -31,10 +31,11 @@
3131

3232
import pyarrow as pa
3333
from datafusion import SessionContext, col, lit, functions as F
34+
from util import get_data_path
3435

3536
ctx = SessionContext()
3637

37-
df = ctx.read_parquet("data/lineitem.parquet")
38+
df = ctx.read_parquet(get_data_path("lineitem.parquet"))
3839

3940
# It may be that the date can be hard coded, based on examples shown.
4041
# This approach will work with any date range in the provided data set.
@@ -45,7 +46,7 @@
4546

4647
# From the given problem, this is how close to the last date in the database we
4748
# want to report results for. It should be between 60-120 days before the end.
48-
DAYS_BEFORE_FINAL = 68
49+
DAYS_BEFORE_FINAL = 90
4950

5051
# Note: this is a hack on setting the values. It should be set differently once
5152
# https://github.com/apache/datafusion-python/issues/665 is resolved.
@@ -63,13 +64,13 @@
6364
[
6465
F.sum(col("l_quantity")).alias("sum_qty"),
6566
F.sum(col("l_extendedprice")).alias("sum_base_price"),
66-
F.sum(col("l_extendedprice") * (lit(1.0) - col("l_discount"))).alias(
67+
F.sum(col("l_extendedprice") * (lit(1) - col("l_discount"))).alias(
6768
"sum_disc_price"
6869
),
6970
F.sum(
7071
col("l_extendedprice")
71-
* (lit(1.0) - col("l_discount"))
72-
* (lit(1.0) + col("l_tax"))
72+
* (lit(1) - col("l_discount"))
73+
* (lit(1) + col("l_tax"))
7374
).alias("sum_charge"),
7475
F.avg(col("l_quantity")).alias("avg_qty"),
7576
F.avg(col("l_extendedprice")).alias("avg_price"),

examples/tpch/q02_minimum_cost_supplier.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -31,8 +31,10 @@
3131

3232
import datafusion
3333
from datafusion import SessionContext, col, lit, functions as F
34+
from util import get_data_path
3435

35-
# This is the part we're looking for
36+
# This is the part we're looking for. Values selected here differ from the spec in order to run
37+
# unit tests on a small data set.
3638
SIZE_OF_INTEREST = 15
3739
TYPE_OF_INTEREST = "BRASS"
3840
REGION_OF_INTEREST = "EUROPE"
@@ -41,10 +43,10 @@
4143

4244
ctx = SessionContext()
4345

44-
df_part = ctx.read_parquet("data/part.parquet").select_columns(
46+
df_part = ctx.read_parquet(get_data_path("part.parquet")).select_columns(
4547
"p_partkey", "p_mfgr", "p_type", "p_size"
4648
)
47-
df_supplier = ctx.read_parquet("data/supplier.parquet").select_columns(
49+
df_supplier = ctx.read_parquet(get_data_path("supplier.parquet")).select_columns(
4850
"s_acctbal",
4951
"s_name",
5052
"s_address",
@@ -53,13 +55,13 @@
5355
"s_nationkey",
5456
"s_suppkey",
5557
)
56-
df_partsupp = ctx.read_parquet("data/partsupp.parquet").select_columns(
58+
df_partsupp = ctx.read_parquet(get_data_path("partsupp.parquet")).select_columns(
5759
"ps_partkey", "ps_suppkey", "ps_supplycost"
5860
)
59-
df_nation = ctx.read_parquet("data/nation.parquet").select_columns(
61+
df_nation = ctx.read_parquet(get_data_path("nation.parquet")).select_columns(
6062
"n_nationkey", "n_regionkey", "n_name"
6163
)
62-
df_region = ctx.read_parquet("data/region.parquet").select_columns(
64+
df_region = ctx.read_parquet(get_data_path("region.parquet")).select_columns(
6365
"r_regionkey", "r_name"
6466
)
6567

examples/tpch/q03_shipping_priority.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
"""
2929

3030
from datafusion import SessionContext, col, lit, functions as F
31+
from util import get_data_path
3132

3233
SEGMENT_OF_INTEREST = "BUILDING"
3334
DATE_OF_INTEREST = "1995-03-15"
@@ -36,13 +37,13 @@
3637

3738
ctx = SessionContext()
3839

39-
df_customer = ctx.read_parquet("data/customer.parquet").select_columns(
40+
df_customer = ctx.read_parquet(get_data_path("customer.parquet")).select_columns(
4041
"c_mktsegment", "c_custkey"
4142
)
42-
df_orders = ctx.read_parquet("data/orders.parquet").select_columns(
43+
df_orders = ctx.read_parquet(get_data_path("orders.parquet")).select_columns(
4344
"o_orderdate", "o_shippriority", "o_custkey", "o_orderkey"
4445
)
45-
df_lineitem = ctx.read_parquet("data/lineitem.parquet").select_columns(
46+
df_lineitem = ctx.read_parquet(get_data_path("lineitem.parquet")).select_columns(
4647
"l_orderkey", "l_extendedprice", "l_discount", "l_shipdate"
4748
)
4849

@@ -73,9 +74,9 @@
7374

7475
df = df.sort(col("revenue").sort(ascending=False), col("o_orderdate").sort())
7576

76-
# Only return 100 results
77+
# Only return 10 results
7778

78-
df = df.limit(100)
79+
df = df.limit(10)
7980

8081
# Change the order that the columns are reported in just to match the spec
8182

examples/tpch/q04_order_priority_checking.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
from datetime import datetime
3030
import pyarrow as pa
3131
from datafusion import SessionContext, col, lit, functions as F
32+
from util import get_data_path
3233

3334
# Ideally we could put 3 months into the interval. See note below.
3435
INTERVAL_DAYS = 92
@@ -38,10 +39,10 @@
3839

3940
ctx = SessionContext()
4041

41-
df_orders = ctx.read_parquet("data/orders.parquet").select_columns(
42+
df_orders = ctx.read_parquet(get_data_path("orders.parquet")).select_columns(
4243
"o_orderdate", "o_orderpriority", "o_orderkey"
4344
)
44-
df_lineitem = ctx.read_parquet("data/lineitem.parquet").select_columns(
45+
df_lineitem = ctx.read_parquet(get_data_path("lineitem.parquet")).select_columns(
4546
"l_orderkey", "l_commitdate", "l_receiptdate"
4647
)
4748

0 commit comments

Comments
 (0)
0