Return TensorForest feature importances in inference dict instead of using tf.Print. · jbenjos/tensorflow@5c21d55 · GitHub

Commit 5c21d55

Return TensorForest feature importances in inference dict instead of using tf.Print.
Change: 150659388
1 parent 4593809 commit 5c21d55

2 files changed: +7 −5 lines changed

tensorflow/contrib/tensor_forest/client/eval_metrics.py

Lines changed: 2 additions & 0 deletions
@@ -30,6 +30,8 @@
 INFERENCE_PROB_NAME = prediction_key.PredictionKey.CLASSES
 INFERENCE_PRED_NAME = prediction_key.PredictionKey.PROBABILITIES
 
+FEATURE_IMPORTANCE_NAME = 'global_feature_importance'
+
 
 def _top_k_generator(k):
   def _top_k(probabilities, targets):
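
The new constant only names the dictionary key; nothing in eval_metrics.py computes importances itself. A minimal consumer-side sketch (the helper name below is made up for illustration and is not part of the commit) that looks the value up through the constant instead of hard-coding the 'global_feature_importance' string:

from tensorflow.contrib.tensor_forest.client import eval_metrics


def get_feature_importances(inference_dict):
  # Looks up the importances under the key added above; returns None when the
  # model_fn was built without report_feature_importances.
  return inference_dict.get(eval_metrics.FEATURE_IMPORTANCE_NAME)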

tensorflow/contrib/tensor_forest/client/random_forest.py

Lines changed: 5 additions & 5 deletions
@@ -28,7 +28,6 @@
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import ops
 from tensorflow.python.ops import control_flow_ops
-from tensorflow.python.ops import logging_ops
 from tensorflow.python.ops import math_ops
 from tensorflow.python.ops import state_ops
 from tensorflow.python.platform import tf_logging as logging
@@ -130,6 +129,10 @@ def _model_fn(features, labels, mode):
       inference[eval_metrics.INFERENCE_PRED_NAME] = math_ops.argmax(
           inference[eval_metrics.INFERENCE_PROB_NAME], 1)
 
+      if report_feature_importances:
+        inference[eval_metrics.FEATURE_IMPORTANCE_NAME] = (
+            graph_builder.feature_importances())
+
     # labels might be None if we're doing prediction (which brings up the
     # question of why we force everything to adhere to a single model_fn).
     loss_deps = []
@@ -149,10 +152,7 @@ def _model_fn(features, labels, mode):
     with ops.control_dependencies(loss_deps):
       training_loss = graph_builder.training_loss(
           features, labels, name=LOSS_NAME)
-      if report_feature_importances and mode == model_fn_lib.ModeKeys.EVAL:
-        training_loss = logging_ops.Print(training_loss,
-                                          [graph_builder.feature_importances()],
-                                          summarize=1000)
+
     # Put weights back in
     if weights is not None:
       features[weights_name] = weights
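
With the Print call removed, the importances flow out of the model_fn alongside the other prediction tensors, so a caller can read them from whatever prediction dicts the surrounding estimator yields. A rough consumer-side sketch, assuming an iterable of per-example prediction dicts (for example tf.contrib.learn-style predict() output); everything except the dictionary key is an assumption and is not shown in this commit:

from tensorflow.contrib.tensor_forest.client import eval_metrics


def print_global_feature_importances(prediction_dicts):
  # The key name suggests the importances describe the trained forest as a
  # whole rather than an individual example, so one prediction dict is enough.
  first = next(iter(prediction_dicts))
  importances = first[eval_metrics.FEATURE_IMPORTANCE_NAME]
  for index, value in enumerate(importances):
    print('feature %d: importance %f' % (index, value))

Under the old scheme the same numbers were only visible as a logging_ops.Print side effect during EVAL, which made them hard to consume programmatically.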
