get-started: Use dvclive · iterative/example-repos-dev@95ce158 · GitHub

Commit 95ce158

daavoo authored and shcheklein committed
get-started: Use dvclive
1 parent bec1b66 commit 95ce158

File tree

11 files changed: +133 -52 lines changed

example-get-started/analyze.py

Lines changed: 42 additions & 0 deletions
@@ -0,0 +1,42 @@
+import io
+import os
+import random
+import re
+import sys
+import xml.etree.ElementTree
+
+
+if len(sys.argv) != 3:
+    sys.stderr.write("Arguments error. Usage:\n")
+    sys.stderr.write("\tpython analyze.py data-file output-file\n")
+    sys.exit(1)
+
+target = 40000
+split = 0.3
+
+
+def lines_matched_test(fd, test):
+    for line in fd:
+        try:
+            attr = xml.etree.ElementTree.fromstring(line).attrib
+            if test(attr.get("Tags", "")):
+                yield line
+        except Exception as ex:
+            sys.stderr.write(f"Skipping the broken line: {ex}\n")
+
+
+def process_posts(fd_in, fd_not, fd_out):
+    count = 0
+    in_lines = lines_matched_test(fd_in, lambda x: "<r>" in x)
+    not_lines = lines_matched_test(fd_not, lambda x: "<r>" not in x)
+    while count < target:
+        line = next(not_lines) if random.random() > split else next(in_lines)
+        fd_out.write(line)
+        count += 1
+
+
+with io.open(sys.argv[1], encoding="utf8") as fd_in:
+    with io.open(sys.argv[1], encoding="utf8") as fd_not:
+        with io.open(sys.argv[2], "w", encoding="utf8") as fd_out:
+            process_posts(fd_in, fd_not, fd_out)
+

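Note: analyze.py builds a mixed sample by parsing each row of the StackOverflow dump and testing its `Tags` attribute, exactly as `lines_matched_test` does above. A minimal sketch of that parsing step, with an invented row for illustration:

```python
import xml.etree.ElementTree

# Invented example row; real StackOverflow dump rows carry many more attributes.
row = '<row Id="123" Tags="&lt;r&gt;&lt;dataframe&gt;" />'

attr = xml.etree.ElementTree.fromstring(row).attrib
tags = attr.get("Tags", "")  # XML entities decode to "<r><dataframe>"
print("<r>" in tags)         # True -> lines_matched_test would yield this line
```
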
example-get-started/code/.github/workflows/cml.yaml

Lines changed: 6 additions & 1 deletion
@@ -23,10 +23,15 @@ jobs:
           echo "# CML Report" > report.md
           echo "## Plots" >> report.md
           dvc plots diff $PREVIOUS_REF workspace \
-            --show-vega --targets prc.json > vega.json
+            --show-vega --targets evaluation/plots/precision_recall.json > vega.json
           vl2svg vega.json prc.svg
           cml publish prc.svg --title "Precision & Recall" --md >> report.md
 
+          dvc plots show \
+            --show-vega evaluation/plots/predictions.json > vega.json
+          vl2svg vega.json confusion.svg
+          cml publish confusion.svg --title "Confusion Matrix" --md >> report.md
+
           echo "## Metrics" >> report.md
           echo "### $PREVIOUS_REF → workspace" >> report.md
           dvc metrics diff $PREVIOUS_REF --show-md >> report.md

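Note: the new workflow step renders `evaluation/plots/predictions.json`, written by `src/evaluate.py`, as a confusion matrix. A hedged sketch of inspecting that file locally (assumes the evaluation stage has already run; `confusion_matrix` is standard scikit-learn):

```python
import json

from sklearn.metrics import confusion_matrix

# Each record is {"actual": 0 or 1, "predicted": 0 or 1}, as written by evaluate.py.
with open("evaluation/plots/predictions.json") as f:
    records = json.load(f)

y_true = [r["actual"] for r in records]
y_pred = [r["predicted"] for r in records]
print(confusion_matrix(y_true, y_pred))  # 2x2 matrix: rows = actual, cols = predicted
```
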
example-get-started/code/README.md

Lines changed: 3 additions & 2 deletions
@@ -8,8 +8,8 @@ introduction into basic DVC concepts.
 
 The project is a natural language processing (NLP) binary classifier problem of
 predicting tags for a given StackOverflow question. For example, we want one
-classifier which can predict a post that is about the Python language by tagging
-it `python`.
+classifier which can predict a post that is about the R language by tagging it
+`R`.
 
 🐛 Please report any issues found in this project here -
 [example-repos-dev](https://github.com/iterative/example-repos-dev).
@@ -160,3 +160,4 @@ $ tree
 ├── requirements.txt # <-- Python dependencies needed in the project
 └── train.py
 ```
+

example-get-started/code/params.yaml

Lines changed: 2 additions & 1 deletion
@@ -3,10 +3,11 @@ prepare:
   seed: 20170428
 
 featurize:
-  max_features: 500
+  max_features: 100
   ngrams: 1
 
 train:
   seed: 20170428
   n_est: 50
   min_split: 2
+

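Note: `max_features` drops from 500 to 100. The pipeline stages read these values from params.yaml; a minimal sketch of that lookup (illustrative, not copied from the repo's scripts):

```python
import yaml

with open("params.yaml") as f:
    params = yaml.safe_load(f)

max_features = params["featurize"]["max_features"]  # 100 after this change
ngrams = params["featurize"]["ngrams"]              # 1
```
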
example-get-started/code/src/evaluate.py

Lines changed: 36 additions & 26 deletions
@@ -4,45 +4,48 @@
 import pickle
 import sys
 
-import sklearn.metrics as metrics
+import pandas as pd
+from sklearn import metrics
+from sklearn import tree
+from dvclive import Live
+from matplotlib import pyplot as plt
 
-if len(sys.argv) != 6:
+
+live = Live("evaluation")
+
+if len(sys.argv) != 3:
     sys.stderr.write("Arguments error. Usage:\n")
-    sys.stderr.write("\tpython evaluate.py model features scores prc roc\n")
+    sys.stderr.write("\tpython evaluate.py model features\n")
     sys.exit(1)
 
 model_file = sys.argv[1]
 matrix_file = os.path.join(sys.argv[2], "test.pkl")
-scores_file = sys.argv[3]
-prc_file = sys.argv[4]
-roc_file = sys.argv[5]
 
 with open(model_file, "rb") as fd:
     model = pickle.load(fd)
 
 with open(matrix_file, "rb") as fd:
-    matrix = pickle.load(fd)
+    matrix, feature_names = pickle.load(fd)
 
 labels = matrix[:, 1].toarray()
 x = matrix[:, 2:]
 
 predictions_by_class = model.predict_proba(x)
 predictions = predictions_by_class[:, 1]
 
-precision, recall, prc_thresholds = metrics.precision_recall_curve(labels, predictions)
-fpr, tpr, roc_thresholds = metrics.roc_curve(labels, predictions)
-
-avg_prec = metrics.average_precision_score(labels, predictions)
-roc_auc = metrics.roc_auc_score(labels, predictions)
-
-with open(scores_file, "w") as fd:
-    json.dump({"avg_prec": avg_prec, "roc_auc": roc_auc}, fd, indent=4)
+# Use dvclive to log a few simple plots ...
+live.log_plot("roc", labels, predictions)
+live.log("avg_prec", metrics.average_precision_score(labels, predictions))
+live.log("roc_auc", metrics.roc_auc_score(labels, predictions))
 
+# ... but actually it can be done with dumping data points into a file:
 # ROC has a drop_intermediate arg that reduces the number of points.
 # https://scikit-learn.org/stable/modules/generated/sklearn.metrics.roc_curve.html#sklearn.metrics.roc_curve.
 # PRC lacks this arg, so we manually reduce to 1000 points as a rough estimate.
+precision, recall, prc_thresholds = metrics.precision_recall_curve(labels, predictions)
 nth_point = math.ceil(len(prc_thresholds) / 1000)
 prc_points = list(zip(precision, recall, prc_thresholds))[::nth_point]
+prc_file = "evaluation/plots/precision_recall.json"
 with open(prc_file, "w") as fd:
     json.dump(
         {
@@ -55,14 +58,21 @@
         indent=4,
     )
 
-with open(roc_file, "w") as fd:
-    json.dump(
-        {
-            "roc": [
-                {"fpr": fp, "tpr": tp, "threshold": t}
-                for fp, tp, t in zip(fpr, tpr, roc_thresholds)
-            ]
-        },
-        fd,
-        indent=4,
-    )
+
+# ... confusion matrix plot
+predictions = [{
+    "actual": int(actual),
+    "predicted": 1 if predicted > 0.5 else 0
+} for actual, predicted in zip(labels, predictions)]
+with open("evaluation/plots/predictions.json", "w") as f:
+    json.dump(predictions, f)
+
+# ... and finally, we can dump an image, it's also supported:
+fig, axes = plt.subplots(dpi=800)
+fig.subplots_adjust(bottom=0.2, top=0.95)
+importances = model.feature_importances_
+forest_importances = pd.Series(importances, index=feature_names).nlargest(n=30)
+axes.set_ylabel("Mean decrease in impurity")
+forest_importances.plot.bar(ax=axes)
+fig.savefig('evaluation/importance.png')
+

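Note: the rewritten evaluate.py logs through dvclive instead of writing scores.json and roc.json by hand. A minimal standalone sketch using the same calls the diff uses, with toy values invented here:

```python
from dvclive import Live

# Toy labels and scores, invented for illustration.
labels = [0, 1, 1, 0, 1]
scores = [0.1, 0.9, 0.7, 0.4, 0.8]

live = Live("evaluation")              # outputs land under evaluation/
live.log("roc_auc", 0.83)              # scalar metric
live.log_plot("roc", labels, scores)   # curve data for `dvc plots`
```
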
example-get-started/code/src/featurization.py

Lines changed: 6 additions & 4 deletions
@@ -38,7 +38,7 @@ def get_df(data):
     return df
 
 
-def save_matrix(df, matrix, output):
+def save_matrix(df, matrix, names, output):
     id_matrix = sparse.csr_matrix(df.id.astype(np.int64)).T
     label_matrix = sparse.csr_matrix(df.label.astype(np.int64)).T
 
@@ -48,7 +48,7 @@ def save_matrix(df, matrix, output):
     sys.stderr.write(msg.format(output, result.shape, result.dtype))
 
     with open(output, "wb") as fd:
-        pickle.dump(result, fd)
+        pickle.dump((result, names), fd)
     pass
 
 
@@ -64,16 +64,18 @@ def save_matrix(df, matrix, output):
 
 bag_of_words.fit(train_words)
 train_words_binary_matrix = bag_of_words.transform(train_words)
+feature_names = bag_of_words.get_feature_names_out()
 tfidf = TfidfTransformer(smooth_idf=False)
 tfidf.fit(train_words_binary_matrix)
 train_words_tfidf_matrix = tfidf.transform(train_words_binary_matrix)
 
-save_matrix(df_train, train_words_tfidf_matrix, train_output)
+save_matrix(df_train, train_words_tfidf_matrix, feature_names, train_output)
 
 # Generate test feature matrix
 df_test = get_df(test_input)
 test_words = np.array(df_test.text.str.lower().values.astype("U"))
 test_words_binary_matrix = bag_of_words.transform(test_words)
 test_words_tfidf_matrix = tfidf.transform(test_words_binary_matrix)
 
-save_matrix(df_test, test_words_tfidf_matrix, test_output)
+save_matrix(df_test, test_words_tfidf_matrix, feature_names, test_output)
+

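Note: featurization.py now pickles a `(matrix, feature_names)` tuple instead of a bare matrix, which is why train.py and evaluate.py unpack two values on load. A minimal round-trip sketch with toy stand-ins:

```python
import pickle

# Toy stand-ins; the real code stores a scipy sparse matrix plus the
# names from bag_of_words.get_feature_names_out().
matrix = [[1, 0], [0, 1]]
feature_names = ["python", "dataframe"]

with open("test.pkl", "wb") as fd:
    pickle.dump((matrix, feature_names), fd)  # new tuple format

with open("test.pkl", "rb") as fd:
    matrix, feature_names = pickle.load(fd)   # consumers unpack both values
```
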
example-get-started/code/src/prepare.py

Lines changed: 2 additions & 1 deletion
@@ -48,4 +48,5 @@ def process_posts(fd_in, fd_out_train, fd_out_test, target_tag):
 with io.open(input, encoding="utf8") as fd_in:
     with io.open(output_train, "w", encoding="utf8") as fd_out_train:
         with io.open(output_test, "w", encoding="utf8") as fd_out_test:
-            process_posts(fd_in, fd_out_train, fd_out_test, "<python>")
+            process_posts(fd_in, fd_out_train, fd_out_test, "<r>")
+
example-get-started/code/requirements.txt

Lines changed: 2 additions & 0 deletions
@@ -1,4 +1,6 @@
+dvclive
 pandas
 pyaml
 scikit-learn
 scipy
+matplotlib

example-get-started/code/src/train.py

Lines changed: 1 addition & 1 deletion
@@ -20,7 +20,7 @@
 min_split = params["min_split"]
 
 with open(os.path.join(input, "train.pkl"), "rb") as fd:
-    matrix = pickle.load(fd)
+    matrix, _ = pickle.load(fd)
 
 labels = np.squeeze(matrix[:, 1].toarray())
 x = matrix[:, 2:]

example-get-started/deploy.sh

Lines changed: 3 additions & 1 deletion
@@ -17,7 +17,9 @@ popd
 
 # Requires AWS CLI and write access to `s3://dvc-public/code/get-started/`.
 mv $PACKAGE_DIR/$PACKAGE .
-aws s3 cp --acl public-read $PACKAGE s3://dvc-public/code/get-started/$PACKAGE
+#aws s3 cp --acl public-read $PACKAGE s3://dvc-public/code/get-started/$PACKAGE
+
+exit
 
 # Sanity check
 wget https://code.dvc.org/get-started/$PACKAGE -O $TEST_PACKAGE
