8000 Create FasterRcnnInception.java by pyrator · Pull Request #23 · tensorflow/java-models · GitHub
[go: up one dir, main page]

Skip to content

Comments

Create FasterRcnnInception.java#23

Merged
karllessard merged 4 commits intotensorflow:masterfrom
pyrator:patch-1
Apr 7, 2021
Merged

Create FasterRcnnInception.java#23
karllessard merged 4 commits intotensorflow:masterfrom
pyrator:patch-1

Conversation

@pyrator
Copy link
Contributor
@pyrator pyrator commented Apr 3, 2021

Sample code to load image and perform Object detection with faster_rcnn/inception_resnet_v2_1024x1024

Sample code to load image and perform Object detection with faster_rcnn/inception_resnet_v2_1024x1024
@google-cla google-cla bot added the cla: yes label Apr 3, 2021
Copy link
Collaborator
@Craigacp Craigacp left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for contributing, I've added a few comments for things that need cleaning up or fixing before we can merge it.

@@ -0,0 +1,302 @@
package org.tensorflow.model.examples.objectdetection;
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you add the copyright/license statement?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Will do.

So it appears there's a discrepancy between the web page and running saved_model_cli.
*/

import org.tensorflow.*;
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We prefer explicit imports over star imports.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Will do.

Argument #1
input_tensor: TensorSpec(shape=(1, None, None, 3), dtype=tf.uint8, name='input_tensor')

So it appears there's a discrepancy between the web page and running saved_model_cli.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not clear what the discrepancy is, it just looks like there's a max of 300 objects detected in any given image. Is there something I'm not seeing?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sorry yes
Here's the specifics
the web page states num_detections: a tf.int tensor with only one value, the number of detections [N].
but the actual tensor is DT_FLOAT according to saved_model_cli
also the web page states detection_classes: a tf.int tensor of shape [N] containing detection class index from the label file.
but again the actual tensor is DT_FLOAT according to saved_model_cli
My code is therefore expecting TFloat32

package org.tensorflow.model.examples.objectdetection;
/*

download the model from https://tfhub.dev/tensorflow/faster_rcnn/inception_resnet_v2_1024x1024/1
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This comment should be cleaned up to make it clear how to use the example. Or alternatively put this text in a markdown file and reference the code.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Will do.

import java.util.Map;
import java.util.TreeMap;

public class FasterRcnnInception {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you add some javadoc for this class?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Will do.

TFloat32 detectionClasses = (TFloat32) outputTensorMap.get("detection_classes");
TFloat32 detectionBoxes = (TFloat32) outputTensorMap.get("detection_boxes");
TFloat32 numDetections = (TFloat32) outputTensorMap.get("num_detections");
TFloat32 detectionScores = (TFloat32) outputTensorMap.get("detection_scores");
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These tensors aren't closed.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Will do

long[] shapeArray = imageShape.asArray();

//The given SavedModel SignatureDef input
Map<String, Tensor> feed_dict = new HashMap<>();
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please use Java style names (e.g. feedDict not feed_dict).

* @param tUint8Tensor 3D tensor
* @return 4D tensor
*/
private static TUint8 reshapeTensor(TUint8 tUint8Tensor) {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This will create things in the default eager session, when it might be better to add this as part of the graph. You'll need to close the output here as well as the default eager session is long lived.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes will do. I think I had copied the way in which similar functionality was used in the SimpleMnist example

).asTensor();
}

private static TUint8 drawBoundingBoxes(TUint8 images, TFloat32 boxes, TFloat32 colors ){
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The same comment as reshapeTensor wrt the default eager session.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'll remove the drawBoundingBoxes function for the present until I've fully implemented it. It's part of my todo list.

}
//TODO tf.image.encodeJpeg
ImageIO.write(bufferedImage, "jpg", new File("image2rcnn.jpg"));
} catch (IOException e) {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please log something at least if it throws IOException, otherwise this could produce no output and be confusing for people who run it.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Will do

@karllessard
Copy link
Collaborator

Thanks @pyrator ! Would it make sense to move this new example under cnn/fastrcnn instead of objectdetection? We already have other type of CNN examples under that directory.

@pyrator
Copy link
Contributor Author
pyrator commented Apr 6, 2021

@karllessard sorry - missed your latest comment. If everyone is happy that my latest commit I can move to cnn/fastrcnn package

Copy link
Collaborator
@Craigacp Craigacp left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You can repackage it too per Karl's suggestion. It might also be nice to accept two file paths as input arguments to main, one for the input file and one for the output file, but that's not strictly necessary.

@@ -1,11 +1,25 @@
package org.tensorflow.model.examples.objectdetection;
/*
* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This should be "Copyright 2021".

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Will do

//The given SavedModel SignatureDef input
Map<String, Tensor> feed_dict = new HashMap<>();
feed_dict.put("input_tensor", reshapeTensor(outputImage));
feedDict.put("input_tensor", reshapeTensor);//reshapeTensor(outputImage));
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you remove the commented out code?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Will do

outputImage.shape().asArray()[2]
)
);
TUint8 reshapeTensor = (TUint8) runner.fetch(reshape).run().get(1);
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You should probably use a fresh runner here, so it doesn't re-execute the decode.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Will do

feedDict.put("input_tensor", reshapeTensor);//reshapeTensor(outputImage));
//The given SavedModel MetaGraphDef key
Map<String, Tensor> outputTensorMap = model.function("serving_default").call(feed_dict);
Map<String, Tensor> outputTensorMap = model.function("serving_default").call(feedDict);
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this is going to leak the remaining tensors, as there are more outputs that you aren't closing. At the moment the output of a function isn't autocloseable, so you'll need to iterate the map and close each element.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Will do

Changed copyright year, added parameters for input and output images, close tensors, created new runner




Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nit: that's a lot of empty lines, can you please reformat?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just tidied up abit

}

reshapeTensor.close();
outputImage.close();
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Since this is a tutorial, I think we should show the best approach for closing tensors, which is with try-with-resources. I saw that in a previous commit, you were encapsulating each tensors resulting from the function call in such block, i.e.

            //detection_classes is a model output name
            try (TFloat32 detectionClasses = (TFloat32) outputTensorMap.get("detection_classes");
                 TFloat32 detectionBoxes = (TFloat32) outputTensorMap.get("detection_boxes");
                 TFloat32 numDetections = (TFloat32) outputTensorMap.get("num_detections");
                 TFloat32 detectionScores = (TFloat32) outputTensorMap.get("detection_scores")) {

I'd suggest you revert to that. reshapeTensor and outputImage should also be encapsulated, e.g.

try (TUint8 reshapeTensor = (TUint8) runner.fetch(reshape).run().get(0)) { 
   ... 
}

Even what I'm proposing here does not guarantee that all resources will be closed (since we assume that we fetch all tensors from a session run or a function call), but it is better than closing explicitly like in the current example, which will never happen if an exception in thrown before (so someone copy&pasting that example in a Java function could endup with memory leaks).

We are actually working on a better solution for closing all resources properly without all this headache.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I've made sure that all the tensors created by the model function call are encapsulated in the try block so hopefully they will all be closed. I've also encapsulated the other two tensors.

Encapsulated all tensors and a further tidy up
Copy link
Collaborator
@karllessard karllessard left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks @pyrator !

@karllessard karllessard merged commit a9d975b into tensorflow:master Apr 7, 2021
@pyrator pyrator deleted the patch-1 branch April 7, 2021 21:02
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants

0