# [mobile][android] Tutorial corrections and aligning with the android api #267
Merged: brucejlin1 merged 1 commit into `pytorch:mobile-launch` from `IvanKobzarev:ik_android_tutorial_fixes_1007` on Oct 9, 2019.
@@ -10,12 +10,12 @@ published: true

# Android

-## Quick start with a HelloWorld example
+## Quickstart with a HelloWorld Example

-[HelloWorld](https://github.com/pytorch/android-demo-app/tree/master/HelloWorldApp) is a simple image classification application that demonstrates how to use PyTorch android api.
+[HelloWorld](https://github.com/pytorch/android-demo-app/tree/master/HelloWorldApp) is a simple image classification application that demonstrates how to use the PyTorch Android API.
This application runs a TorchScript-serialized TorchVision pretrained resnet18 model on a static image which is packaged inside the app as an android asset.

-#### 1. Model preparation
+#### 1. Model Preparation

Let’s start with model preparation. If you are familiar with PyTorch, you probably already know how to train and save your model. In case you don’t, we are going to use a pre-trained image classification model (Resnet18), which is packaged in [TorchVision](https://pytorch.org/docs/stable/torchvision/index.html).
To install it, run the command below:
@@ -44,13 +44,13 @@ More details about TorchScript you can find in [tutorials on pytorch.org](https:

git clone https://github.com/pytorch/android-demo-app.git
cd HelloWorldApp
```
-If [android sdk]() and [android ndk]() are already installed you can install this application to the connected android device or emulator with:
+If the [Android SDK]() and [Android NDK]() are already installed, you can install this application to the connected android device or emulator with:
```
./gradlew installDebug
```

We recommend you to open this project in [Android Studio](https://developer.android.com/studio),
-in that case you will be able to install android ndk and android sdk using Android Studio UI.
+in that case you will be able to install the Android NDK and Android SDK using the Android Studio UI.

#### 3. Gradle dependencies

@@ -66,15 +66,15 @@ dependencies {

implementation 'org.pytorch:pytorch_android_torchvision:1.3.0'
}
```
-Where `org.pytorch:pytorch_android` is the main dependency with pytorch android api, including libtorch native library for all 4 android abis (armeabi-v7a, arm64-v8a, x86, x86_64).
+Where `org.pytorch:pytorch_android` is the main dependency with the PyTorch Android API, including the libtorch native library for all 4 android abis (armeabi-v7a, arm64-v8a, x86, x86_64).
Further in this doc you can find how to rebuild it only for a specific list of android abis.

`org.pytorch:pytorch_android_torchvision` - an additional library with utility functions for converting `android.media.Image` and `android.graphics.Bitmap` to tensors.

-#### 4. Reading static image from android asset
+#### 4. Reading image from Android Asset

-All logic happens in [org.pytorch.helloworld.MainActivity](https://github.com/pytorch/android-demo-app/blob/master/HelloWorldApp/app/src/main/java/org/pytorch/helloworld/MainActivity.java#L31-L69).
-As a first step we read `image.jpg` to `android.graphics.Bitmap` using standard android api.
+All the logic happens in [`org.pytorch.helloworld.MainActivity`](https://github.com/pytorch/android-demo-app/blob/master/HelloWorldApp/app/src/main/java/org/pytorch/helloworld/MainActivity.java#L31-L69).
+As a first step we read `image.jpg` to `android.graphics.Bitmap` using the standard Android API.
```
Bitmap bitmap = BitmapFactory.decodeStream(getAssets().open("image.jpg"));
```
@@ -85,30 +85,30 @@ Module module = Module.load(assetFilePath(this, "model.pt"));
```
`org.pytorch.Module` represents `torch::jit::script::Module` that can be loaded with the `load` method, specifying the file path to the serialized model file.

-#### 6. Preparing input
+#### 6. Preparing Input
```
Tensor inputTensor = TensorImageUtils.bitmapToFloat32Tensor(bitmap,
    TensorImageUtils.TORCHVISION_NORM_MEAN_RGB, TensorImageUtils.TORCHVISION_NORM_STD_RGB);
```
-`org.pytorch.torchvision.TensorImageUtils` is part of 'org.pytorch:pytorch_android_torchvision' library.
-`TensorImageUtils#bitmapToFloat32Tensor` method creates tensor in [torch vision format](https://pytorch.org/docs/stable/torchvision/models.html) using `android.graphics.Bitmap` as a source.
+`org.pytorch.torchvision.TensorImageUtils` is part of the `org.pytorch:pytorch_android_torchvision` library.
+The `TensorImageUtils#bitmapToFloat32Tensor` method creates tensors in the [torchvision format](https://pytorch.org/docs/stable/torchvision/models.html) using `android.graphics.Bitmap` as a source.

> All pre-trained models expect input images normalized in the same way, i.e. mini-batches of 3-channel RGB images of shape (3 x H x W), where H and W are expected to be at least 224.
-> The images have to be loaded in to a range of [0, 1] and then normalized using mean = [0.485, 0.456, 0.406] and std = [0.229, 0.224, 0.225]
+> The images have to be loaded into a range of `[0, 1]` and then normalized using `mean = [0.485, 0.456, 0.406]` and `std = [0.229, 0.224, 0.225]`

-`inputTensor`'s shape is 1x3xHxW, where H and W are bitmap height and width appropriately.
+`inputTensor`'s shape is `1x3xHxW`, where `H` and `W` are the bitmap height and width respectively.
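As a plain-Java sketch of what that per-channel normalization does (the mean/std constants mirror the torchvision values quoted above; the class and method names here are made up for illustration, not part of the PyTorch Android API):

```java
public class NormalizeSketch {
    static final float[] MEAN = {0.485f, 0.456f, 0.406f};
    static final float[] STD  = {0.229f, 0.224f, 0.225f};

    // Maps an 8-bit channel value (0..255) to the normalized float a
    // torchvision-style model expects; c is the channel index (0=R,1=G,2=B).
    static float normalize(int value, int c) {
        float scaled = value / 255.0f;      // load into the [0, 1] range
        return (scaled - MEAN[c]) / STD[c]; // then subtract mean, divide by std
    }

    public static void main(String[] args) {
        // A mid-gray red-channel value, normalized:
        System.out.println(normalize(128, 0));
    }
}
```

`bitmapToFloat32Tensor` applies this to every pixel of every channel internally; the sketch only shows the arithmetic.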

#### 7. Run Inference

```
-Tensor outputTensor = module.forward(IValue.tensor(inputTensor)).getTensor();
+Tensor outputTensor = module.forward(IValue.from(inputTensor)).toTensor();
float[] scores = outputTensor.getDataAsFloatArray();
```

The `org.pytorch.Module.forward` method runs the loaded module's `forward` method and gets the result as an `org.pytorch.Tensor` outputTensor with shape `1x1000`.

#### 8. Processing results
-It's content is retrieved using `org.pytorch.Tensor.getDataAsFloatArray()` method that returns java array of floats with scores for every image net class.
+Its content is retrieved using the `org.pytorch.Tensor.getDataAsFloatArray()` method that returns a Java array of floats with scores for every ImageNet class.

After that we just find the index with the maximum score and retrieve the predicted class name from the `ImageNetClasses.IMAGENET_CLASSES` array that contains all ImageNet classes.
@@ -124,14 +124,16 @@ for (int i = 0; i < scores.length; i++) {

String className = ImageNetClasses.IMAGENET_CLASSES[maxScoreIdx];
```

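The diff only shows the start of that loop; a self-contained sketch of the argmax step it performs (plain Java, no Android dependencies; the class name is illustrative) could look like:

```java
public class ArgMaxSketch {
    // Returns the index of the largest score, as done for the ImageNet classes.
    static int argMax(float[] scores) {
        int maxScoreIdx = 0;
        for (int i = 1; i < scores.length; i++) {
            if (scores[i] > scores[maxScoreIdx]) {
                maxScoreIdx = i;
            }
        }
        return maxScoreIdx;
    }

    public static void main(String[] args) {
        float[] scores = {0.1f, 2.5f, -0.3f, 1.7f};
        System.out.println(argMax(scores)); // prints 1
    }
}
```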

-In the following sections you can find detailed explanation of pytorch android api, code walk through for bigger [demo application](https://github.com/pytorch/android-demo-app/tree/master/PyTorchDemoApp), implementation details of api and how to customize and build it from the source.
+In the following sections you can find a detailed explanation of the PyTorch Android API, a code walkthrough for a bigger [demo application](https://github.com/pytorch/android-demo-app/tree/master/PyTorchDemoApp),
+implementation details of the API, and how to customize and build it from source.


-## Pytorch demo app
+## PyTorch Demo Application

-Bigger example of application that does image classification from android camera output and text classification you can find in the [same github repo](https://github.com/pytorch/android-demo-app/tree/master/PyTorchDemoApp).
+We have also created another, more complex PyTorch Android demo application that does image classification from camera output and text classification, in the [same github repo](https://github.com/pytorch/android-demo-app/tree/master/PyTorchDemoApp).

-To get device camera output in it uses [android cameraX api](https://developer.android.com/training/camerax).
+To get device camera output it uses the [Android CameraX API](https://developer.android.com/training/camerax).
All the logic that works with CameraX is separated into the [`org.pytorch.demo.vision.AbstractCameraXActivity`](https://github.com/pytorch/android-demo-app/blob/master/PyTorchDemoApp/app/src/main/java/org/pytorch/demo/vision/AbstractCameraXActivity.java) class.

```
@@ -158,13 +160,13 @@ void setupCameraX() {

void analyzeImage(android.media.Image, int rotationDegrees)
```

-Where `analyzeImage` method processes camera output, `android.media.Image`.
+Where the `analyzeImage` method processes the camera output, `android.media.Image`.

> Review comment: "Where [the] analyzeImage process [the] camera output, android.media.Image. It uses [the] aforementioned..."

-It uses aforementioned [`TensorImageUtils.imageYUV420CenterCropToFloat32Tensor`](https://github.com/pytorch/pytorch/blob/master/android/pytorch_android_torchvision/src/main/java/org/pytorch/torchvision/TensorImageUtils.java#L90) method to convert `android.media.Image` in `YUV420` format to input tensor.
+It uses the aforementioned [`TensorImageUtils.imageYUV420CenterCropToFloat32Tensor`](https://github.com/pytorch/pytorch/blob/master/android/pytorch_android_torchvision/src/main/java/org/pytorch/torchvision/TensorImageUtils.java#L90) method to convert an `android.media.Image` in `YUV420` format to an input tensor.

-After getting predicted scores from the model it [finds top K classes](https://github.com/pytorch/android-demo-app/blob/master/PyTorchDemoApp/app/src/main/java/org/pytorch/demo/vision/ImageClassificationActivity.java#L153-L161) with the highest scores and shows on the UI.
+After getting predicted scores from the model, it finds the top K classes with the highest scores and shows them on the UI.

-## Building pytorch android from source
+## Building PyTorch Android from Source

In some cases you might want to use a local build of pytorch android, for example to build a custom libtorch binary with another set of operators or to make local changes.

@@ -175,19 +177,22 @@ cd pytorch

sh ./scripts/build_pytorch_android.sh
```

-Its workflow contains several steps:
-1. Builds libtorch for android for all 4 android abis (armeabi-v7a, arm64-v8a, x86, x86_64)
-2. Creates symbolic links to the results of those builds:
+The workflow contains several steps:
+
+1\. Build libtorch for android for all 4 android abis (armeabi-v7a, arm64-v8a, x86, x86_64)
+
+2\. Create symbolic links to the results of those builds:
`android/pytorch_android/src/main/jniLibs/${abi}` to the directory with output libraries,
`android/pytorch_android/src/main/cpp/libtorch_include/${abi}` to the directory with headers. These directories are used to build the `libpytorch.so` library that will be loaded on the android device.

-3. And finally runs `gradle` in `android/pytorch_android` directory with task `assembleRelease`
+3\. And finally run `gradle` in the `android/pytorch_android` directory with the task `assembleRelease`

-Script requires that android sdk, android ndk and gradle are installed.
+The script requires that the Android SDK, Android NDK and gradle are installed.
They are specified as environment variables:

-`ANDROID_HOME` - path to [android sdk](https://developer.android.com/studio/command-line/sdkmanager.html)
+`ANDROID_HOME` - path to the [Android SDK](https://developer.android.com/studio/command-line/sdkmanager.html)

-`ANDROID_NDK` - path to [android ndk](https://developer.android.com/studio/projects/install-ndk)
+`ANDROID_NDK` - path to the [Android NDK](https://developer.android.com/studio/projects/install-ndk)

`GRADLE_HOME` - path to [gradle](https://gradle.org/releases/)

@@ -226,7 +231,7 @@ dependencies {

}
```

-At the moment for the case of using aar files directly we need additional configuration due to packaging specific (libfbjni.so is packaged in both pytorch_android_fbjni.aar and pytorch_android.aar).
+At the moment, when using the aar files directly, we need additional configuration due to packaging specifics (`libfbjni.so` is packaged in both `pytorch_android_fbjni.aar` and `pytorch_android.aar`).
```
packagingOptions {
    pickFirst "**/libfbjni.so"

@@ -235,52 +240,49 @@ packagingOptions {

## API Details

-Main part of java api includes 3 classes:
+The main part of the Java API includes 3 classes:
```
org.pytorch.Module
org.pytorch.IValue
org.pytorch.Tensor
```

-If the reader is familiar with pytorch python api, we can think that org.pytorch.Tensor represents torch.tensor, org.pytorch.Module torch.Module<?>, while org.pytorch.IValue represents value of TorchScript variable, supporting all its types. ( https://pytorch.org/docs/stable/jit.html#types )
+If the reader is familiar with the PyTorch Python API, we can think of `org.pytorch.Tensor` representing `torch.tensor`, `org.pytorch.Module` representing `torch.Module`, and `org.pytorch.IValue` representing the value of a TorchScript variable, supporting all its [types](https://pytorch.org/docs/stable/jit.html#types).

-### org.pytorch.Tensor (Tensor)
-[github](https://github.com/pytorch/pytorch/blob/master/android/pytorch_android/src/main/java/org/pytorch/Tensor.java)
+### [`org.pytorch.Tensor`](https://github.com/pytorch/pytorch/blob/master/android/pytorch_android/src/main/java/org/pytorch/Tensor.java)

Tensor supports dtypes `uint8, int8, float32, int32, float64, int64`.
Tensor holds its data in a direct byte buffer of the proper type with native bit order.

To create a Tensor the user can use one of the factory methods:
```
-Tensor newUInt8Tensor(long[] shape, ByteBuffer data)
-Tensor newUInt8Tensor(long[] shape, byte[] data)
-
-Tensor newInt8Tensor(long[] shape, ByteBuffer data)
-Tensor newInt8Tensor(long[] shape, byte[] data)
-
-Tensor newFloat32Tensor(long[] shape, FloatBuffer data)
-Tensor newFloat32Tensor(long[] shape, float[] data)
-
-Tensor newInt32Tensor(long[] shape, IntBuffer data)
-Tensor newInt32Tensor(long[] shape, int[] data)
-
-Tensor newFloat64Tensor(long[] shape, DoubleBuffer data)
-Tensor newFloat64Tensor(long[] shape, double[] data)
-
-Tensor newInt64Tensor(long[] shape, LongBuffer data)
-Tensor newInt64Tensor(long[] shape, long[] data)
+Tensor fromBlobUnsigned(ByteBuffer data, long[] shape)
+Tensor fromBlobUnsigned(byte[] data, long[] shape)
+
+Tensor fromBlob(ByteBuffer data, long[] shape)
+Tensor fromBlob(byte[] data, long[] shape)
+
+Tensor fromBlob(FloatBuffer data, long[] shape)
+Tensor fromBlob(float[] data, long[] shape)
+
+Tensor fromBlob(IntBuffer data, long[] shape)
+Tensor fromBlob(int[] data, long[] shape)
+
+Tensor fromBlob(DoubleBuffer data, long[] shape)
+Tensor fromBlob(double[] data, long[] shape)
+
+Tensor fromBlob(LongBuffer data, long[] shape)
+Tensor fromBlob(long[] data, long[] shape)
```
Where the parameter `long[] shape` is the shape of the Tensor as an array of longs.

-Content of the Tensor can be provided either as (a) java array or (b) as java.nio.DirectByteBuffer of proper type with native bit order.
+Content of the Tensor can be provided either as (a) a Java array or (b) a direct `java.nio` buffer of the proper type with native bit order.

-In case of (a) proper DirectByteBuffer will be created internally. (b) case has an advantage that user can keep the reference to DirectByteBuffer and change its content in future for the next run, avoiding allocation of DirectByteBuffer for repeated runs.
+In case (a) a proper direct buffer will be created internally. Case (b) has the advantage that the user can keep the reference to the direct buffer and change its content before the next run, avoiding allocation of a new buffer for repeated runs.
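Case (b) can be sketched with plain `java.nio` (no PyTorch dependency is needed to see the idea; the helper name is made up): allocate one direct buffer with native byte order up front, then refill it between runs.

```java
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.nio.FloatBuffer;

public class BufferReuseSketch {
    // Allocates a direct float buffer with native bit order, as the
    // Tensor factory methods expect for case (b).
    static FloatBuffer allocateInput(int numel) {
        return ByteBuffer.allocateDirect(numel * 4) // 4 bytes per float32
                .order(ByteOrder.nativeOrder())
                .asFloatBuffer();
    }

    public static void main(String[] args) {
        FloatBuffer input = allocateInput(3 * 224 * 224);
        // Refill the same buffer before each run instead of reallocating:
        input.rewind();
        input.put(0, 0.5f);
        System.out.println(input.isDirect()); // prints true
    }
}
```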

-Java’s primitive type byte is signed and java does not have unsigned 8 bit type. For dtype=uint8 api uses byte that will be reinterpretted as uint8 on native side. On java side unsigned value of byte can be read as (byte & 0xFF).
+Java’s primitive type `byte` is signed and Java does not have an unsigned 8-bit type. For dtype=uint8 the Java API uses the primitive `byte`, which will be reinterpreted as uint8 on the native side. On the Java side the unsigned value of a byte can be read as `byte & 0xFF`.
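A minimal sketch of that reinterpretation in plain Java (the class name is illustrative):

```java
public class UnsignedByteSketch {
    // Reads a signed Java byte as the unsigned 0..255 value it encodes.
    static int asUnsigned(byte b) {
        return b & 0xFF;
    }

    public static void main(String[] args) {
        byte pixel = (byte) 200;               // stored as -56 in Java's signed byte
        System.out.println(asUnsigned(pixel)); // prints 200
    }
}
```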
|
||
#### Tensor content layout | ||
#### Tensor Content Layout | ||
|
||
Tensor content is represented as a one dimensional array (buffer), | ||
where the first element has all zero indexes T\[0, ... 0\]. | ||
|
@@ -293,17 +295,17 @@ Tensor has methods to check its dtype:

```
int dtype()
```
-That returns one of the dtype codes:
+That returns one of the `DType` enum elements:
```
-Tensor.DTYPE_UINT8
-Tensor.DTYPE_INT8
-Tensor.DTYPE_INT32
-Tensor.DTYPE_FLOAT32
-Tensor.DTYPE_INT64
-Tensor.DTYPE_FLOAT64
+DType.UINT8
+DType.INT8
+DType.INT32
+DType.FLOAT32
+DType.INT64
+DType.FLOAT64
```

-The data of Tensor can be read as java array:
+The data of the Tensor can be read as a Java array:
```
byte[] getDataAsUnsignedByteArray()
byte[] getDataAsByteArray()

@@ -312,18 +314,17 @@ long[] getDataAsLongArray()

float[] getDataAsFloatArray()
double[] getDataAsDoubleArray()
```
-These methods throw IllegalStateException if called for inappropriate dtype.
+These methods throw `IllegalStateException` if called for an inappropriate dtype.

-### org.pytorch.IValue (IValue)
-[github](https://github.com/pytorch/pytorch/blob/master/android/pytorch_android/src/main/java/org/pytorch/IValue.java)
+### [`org.pytorch.IValue`](https://github.com/pytorch/pytorch/blob/master/android/pytorch_android/src/main/java/org/pytorch/IValue.java)

-IValue represents a TorchScript variable that can be one of the supported (by torchscript) types ( https://pytorch.org/docs/stable/jit.html#types ). IValue is a tagged union. For every supported type it has a factory method, method to check the type and a getter method to retrieve a value.
-Getters throw IllegalStateException if called for inappropriate type.
+`IValue` represents a TorchScript variable that can be one of the supported (by TorchScript) [types](https://pytorch.org/docs/stable/jit.html#types).
+`IValue` is a tagged union. For every supported type it has a factory method, a method to check the type, and a getter method to retrieve the value.
+Getters throw `IllegalStateException` if called for an inappropriate type.

-### org.pytorch.Module (Module)
-[github](https://github.com/pytorch/pytorch/blob/master/android/pytorch_android/src/main/java/org/pytorch/Module.java)
+### [`org.pytorch.Module`](https://github.com/pytorch/pytorch/blob/master/android/pytorch_android/src/main/java/org/pytorch/Module.java)

-Module is a wrapper of torch.jit.ScriptModule (`torch::jit::script::Module` in pytorch c++ api) which can be constructed with factory method load providing absolute path to the file with serialized TorchScript.
+`Module` is a wrapper of `torch.jit.ScriptModule` (`torch::jit::script::Module` in the PyTorch C++ API) which can be constructed with the factory method `load`, providing an absolute path to the file with the serialized TorchScript.
```
IValue IValue.runMethod(String methodName, IValue... inputs)
```
> Review comment: Capitalization "Preparing Input"