[mobile][android] Tutorial corrections and aligning with the android api #267

Merged
145 changes: 73 additions & 72 deletions _mobile/android.md
@@ -10,12 +10,12 @@ published: true

# Android

## Quick start with a HelloWorld example
## Quickstart with a HelloWorld Example

[HelloWorld](https://github.com/pytorch/android-demo-app/tree/master/HelloWorldApp) is a simple image classification application that demonstrates how to use PyTorch android api.
[HelloWorld](https://github.com/pytorch/android-demo-app/tree/master/HelloWorldApp) is a simple image classification application that demonstrates how to use PyTorch Android API.
This application runs TorchScript serialized TorchVision pretrained resnet18 model on static image which is packaged inside the app as android asset.

#### 1. Model preparation
#### 1. Model Preparation

Let’s start with model preparation. If you are familiar with PyTorch, you probably should already know how to train and save your model. In case you don’t, we are going to use a pre-trained image classification model(Resnet18), which is packaged in [TorchVision](https://pytorch.org/docs/stable/torchvision/index.html).
To install it, run the command below:
@@ -44,13 +44,13 @@ More details about TorchScript you can find in [tutorials on pytorch.org](https:
git clone https://github.com/pytorch/android-demo-app.git
cd HelloWorldApp
```
If [android sdk]() and [android ndk]() are already installed you can install this application to the connected android device or emulator with:
If [Android SDK]() and [Android NDK]() are already installed you can install this application to the connected android device or emulator with:
```
./gradlew installDebug
```

We recommend you to open this project in [Android Studio](https://developer.android.com/studio),
in that case you will be able to install android ndk and android sdk using Android Studio UI.
in that case you will be able to install Android NDK and Android SDK using Android Studio UI.

#### 3. Gradle dependencies

@@ -66,15 +66,15 @@ dependencies {
implementation 'org.pytorch:pytorch_android_torchvision:1.3.0'
}
```
Where `org.pytorch:pytorch_android` is the main dependency with pytorch android api, including libtorch native library for all 4 android abis (armeabi-v7a, arm64-v8a, x86, x86_64).
Where `org.pytorch:pytorch_android` is the main dependency with PyTorch Android API, including libtorch native library for all 4 android abis (armeabi-v7a, arm64-v8a, x86, x86_64).
Further in this doc you can find how to rebuild it only for specific list of android abis.

`org.pytorch:pytorch_android_torchvision` - additional library with utility functions for converting `android.media.Image` and `android.graphics.Bitmap` to tensors.

#### 4. Reading static image from android asset
#### 4. Reading image from Android Asset

All logic happens in [org.pytorch.helloworld.MainActivity](https://github.com/pytorch/android-demo-app/blob/master/HelloWorldApp/app/src/main/java/org/pytorch/helloworld/MainActivity.java#L31-L69).
As a first step we read `image.jpg` to `android.graphics.Bitmap` using standard android api.
All the logic happens in [`org.pytorch.helloworld.MainActivity`](https://github.com/pytorch/android-demo-app/blob/master/HelloWorldApp/app/src/main/java/org/pytorch/helloworld/MainActivity.java#L31-L69).
As a first step we read `image.jpg` to `android.graphics.Bitmap` using the standard Android API.
```
Bitmap bitmap = BitmapFactory.decodeStream(getAssets().open("image.jpg"));
```
@@ -85,30 +85,30 @@ Module module = Module.load(assetFilePath(this, "model.pt"));
```
`org.pytorch.Module` represents `torch::jit::script::Module` that can be loaded with `load` method specifying file path to the serialized to file model.
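
Because `load` takes a file path, a model shipped as an asset first has to exist as a regular file. The `assetFilePath` helper used above is not shown in this diff; the sketch below assumes it copies the asset stream into a file and returns that file's absolute path. `AssetCopySketch` and `copyToFile` are hypothetical names, and the example uses only `java.io` so it runs outside Android.

```java
import java.io.ByteArrayInputStream;
import java.io.File;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;

// Hypothetical helper, not part of the PyTorch Android API.
final class AssetCopySketch {

    /** Copies every byte of {@code in} into {@code target} and returns the target's absolute path. */
    static String copyToFile(InputStream in, File target) throws IOException {
        try (OutputStream out = new FileOutputStream(target)) {
            byte[] buffer = new byte[4 * 1024];
            int read;
            // read() returns -1 once the stream is exhausted
            while ((read = in.read(buffer)) != -1) {
                out.write(buffer, 0, read);
            }
        }
        return target.getAbsolutePath();
    }

    public static void main(String[] args) throws IOException {
        File target = File.createTempFile("model", ".pt");
        target.deleteOnExit();
        // Stand-in for an asset stream; three bytes of dummy data.
        String path = copyToFile(new ByteArrayInputStream(new byte[] {1, 2, 3}), target);
        System.out.println(path + " (" + target.length() + " bytes)");
    }
}
```

In an Android app the stream would presumably come from `getAssets().open(...)` and the target would live under the app's private files directory; the returned path is what would then be handed to `Module.load`.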

#### 6. Preparing input
#### 6. Preparing Input
```
Tensor inputTensor = TensorImageUtils.bitmapToFloat32Tensor(bitmap,
    // Review comment (Contributor): Capitalization "Preparing Input"
TensorImageUtils.TORCHVISION_NORM_MEAN_RGB, TensorImageUtils.TORCHVISION_NORM_STD_RGB);
```
`org.pytorch.torchvision.TensorImageUtils` is part of 'org.pytorch:pytorch_android_torchvision' library.
`TensorImageUtils#bitmapToFloat32Tensor` method creates tensor in [torch vision format](https://pytorch.org/docs/stable/torchvision/models.html) using `android.graphics.Bitmap` as a source.
`org.pytorch.torchvision.TensorImageUtils` is part of `org.pytorch:pytorch_android_torchvision` library.
The `TensorImageUtils#bitmapToFloat32Tensor` method creates tensors in the [torchvision format](https://pytorch.org/docs/stable/torchvision/models.html) using `android.graphics.Bitmap` as a source.

> All pre-trained models expect input images normalized in the same way, i.e. mini-batches of 3-channel RGB images of shape (3 x H x W), where H and W are expected to be at least 224.
> The images have to be loaded in to a range of [0, 1] and then normalized using mean = [0.485, 0.456, 0.406] and std = [0.229, 0.224, 0.225]
> The images have to be loaded in to a range of `[0, 1]` and then normalized using `mean = [0.485, 0.456, 0.406]` and `std = [0.229, 0.224, 0.225]`

`inputTensor`'s shape is 1x3xHxW, where H and W are bitmap height and width appropriately.
`inputTensor`'s shape is `1x3xHxW`, where `H` and `W` are the bitmap height and width respectively.
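
The normalization quoted above can be illustrated without any Android dependency. This is a minimal sketch, not the library's implementation: it assumes the pixels arrive as packed ARGB ints in row-major order, and `NormalizeSketch` and `argbToNormalizedChw` are hypothetical names. It scales each channel to `[0, 1]`, applies the mean and std values from the quote, and writes the result in channel-first (`3xHxW`) order.

```java
// Hypothetical illustration; TensorImageUtils may be implemented differently.
final class NormalizeSketch {
    static final float[] MEAN_RGB = {0.485f, 0.456f, 0.406f};
    static final float[] STD_RGB = {0.229f, 0.224f, 0.225f};

    /**
     * Converts packed ARGB pixels (row-major, width * height entries)
     * into 3 * height * width floats laid out as R plane, G plane, B plane.
     */
    static float[] argbToNormalizedChw(int[] pixels, int width, int height) {
        int plane = width * height;
        if (pixels.length != plane) {
            throw new IllegalArgumentException("expected " + plane + " pixels, got " + pixels.length);
        }
        float[] out = new float[3 * plane];
        for (int i = 0; i < plane; i++) {
            int c = pixels[i];
            // Extract 8-bit channels and scale them to [0, 1]
            float r = ((c >> 16) & 0xFF) / 255.0f;
            float g = ((c >> 8) & 0xFF) / 255.0f;
            float b = (c & 0xFF) / 255.0f;
            // Normalize per channel and store channel-first
            out[i] = (r - MEAN_RGB[0]) / STD_RGB[0];
            out[plane + i] = (g - MEAN_RGB[1]) / STD_RGB[1];
            out[2 * plane + i] = (b - MEAN_RGB[2]) / STD_RGB[2];
        }
        return out;
    }

    public static void main(String[] args) {
        // A single opaque red pixel
        float[] chw = argbToNormalizedChw(new int[] {0xFFFF0000}, 1, 1);
        System.out.println(chw[0] + " " + chw[1] + " " + chw[2]);
    }
}
```

Adding the leading batch dimension of 1 does not change the data order, so the same flat array would back a `1x3xHxW` tensor.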

#### 7. Run Inference

```
Tensor outputTensor = module.forward(IValue.tensor(inputTensor)).getTensor();
Tensor outputTensor = module.forward(IValue.from(inputTensor)).toTensor();
float[] scores = outputTensor.getDataAsFloatArray();
```

`org.pytorch.Module.forward` method runs loaded module's `forward` method and gets result as `org.pytorch.Tensor` outputTensor with shape `1x1000`.

#### 8. Processing results
It's content is retrieved using `org.pytorch.Tensor.getDataAsFloatArray()` method that returns java array of floats with scores for every image net class.
Its content is retrieved using the `org.pytorch.Tensor.getDataAsFloatArray()` method, which returns a Java array of floats with scores for every ImageNet class.

After that we just find index with maximum score and retrieve predicted class name from `ImageNetClasses.IMAGENET_CLASSES` array that contains all ImageNet classes.

@@ -124,14 +124,16 @@ for (int i = 0; i < scores.length; i++) {
String className = ImageNetClasses.IMAGENET_CLASSES[maxScoreIdx];
```

In the following sections you can find detailed explanation of pytorch android api, code walk through for bigger [demo application](https://github.com/pytorch/android-demo-app/tree/master/PyTorchDemoApp), implementation details of api and how to customize and build it from the source.
In the following sections you can find detailed explanations of PyTorch Android API, code walk through for a bigger [demo application](https://github.com/pytorch/android-demo-app/tree/master/PyTorchDemoApp),
implementation details of the API, how to customize and build it from source.

## Pytorch demo app
## PyTorch Demo Application

Bigger example of application that does image classification from android camera output and text classification you can find in the [same github repo](https://github.com/pytorch/android-demo-app/tree/master/PyTorchDemoApp).
We have also created another more complex PyTorch Android demo application that does image classification from camera output and text classification in the [same github repo](https://github.com/pytorch/android-demo-app/tree/master/PyTorchDemoApp).

To get device camera output in it uses [android cameraX api](https://developer.android.com/training/camerax
). All the logic that works with CameraX is separated to [`org.pytorch.demo.vision.AbstractCameraXActivity`](https://github.com/pytorch/android-demo-app/blob/master/PyTorchDemoApp/app/src/main/java/org/pytorch/demo/vision/AbstractCameraXActivity.java) class.
To get device camera output it uses [Android CameraX API](https://developer.android.com/training/camerax
).
All the logic that works with CameraX is separated to [`org.pytorch.demo.vision.AbstractCameraXActivity`](https://github.com/pytorch/android-demo-app/blob/master/PyTorchDemoApp/app/src/main/java/org/pytorch/demo/vision/AbstractCameraXActivity.java) class.


```
@@ -158,13 +160,13 @@ void setupCameraX() {
void analyzeImage(android.media.Image, int rotationDegrees)
```

Where `analyzeImage` method processes camera output, `android.media.Image`.
Where the `analyzeImage` method processes the camera output, `android.media.Image`.

> Review comment (Contributor): "Where [the] analyzeImage process [the] camera output, android.media.Image. It uses [the] aforementioned..."

It uses aforementioned [`TensorImageUtils.imageYUV420CenterCropToFloat32Tensor`](https://github.com/pytorch/pytorch/blob/master/android/pytorch_android_torchvision/src/main/java/org/pytorch/torchvision/TensorImageUtils.java#L90) method to convert `android.media.Image` in `YUV420` format to input tensor.
It uses the aforementioned [`TensorImageUtils.imageYUV420CenterCropToFloat32Tensor`](https://github.com/pytorch/pytorch/blob/master/android/pytorch_android_torchvision/src/main/java/org/pytorch/torchvision/TensorImageUtils.java#L90) method to convert `android.media.Image` in `YUV420` format to input tensor.

After getting predicted scores from the model it [finds top K classes](https://github.com/pytorch/android-demo-app/blob/master/PyTorchDemoApp/app/src/main/java/org/pytorch/demo/vision/ImageClassificationActivity.java#L153-L161) with the highest scores and shows on the UI.
After getting predicted scores from the model it finds top K classes with the highest scores and shows on the UI.
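
Selecting the top K classes from a score array can be sketched as follows. This is not the demo application's code; `TopKSketch` and `topK` are hypothetical names, and the approach (sorting all indices by descending score) is chosen for clarity rather than speed.

```java
import java.util.Arrays;

// Hypothetical illustration of a top-K selection over class scores.
final class TopKSketch {

    /** Returns the indices of the k highest scores, best first. */
    static int[] topK(float[] scores, int k) {
        int n = Math.max(0, Math.min(k, scores.length));
        Integer[] order = new Integer[scores.length];
        for (int i = 0; i < order.length; i++) {
            order[i] = i;
        }
        // Sort indices by descending score; the sort is stable, so ties keep index order
        Arrays.sort(order, (a, b) -> Float.compare(scores[b], scores[a]));
        int[] result = new int[n];
        for (int i = 0; i < n; i++) {
            result[i] = order[i];
        }
        return result;
    }

    public static void main(String[] args) {
        float[] scores = {0.1f, 0.7f, 0.2f, 0.9f};
        System.out.println(Arrays.toString(topK(scores, 2)));
    }
}
```

Each returned index could then be used to look up a class name, in the same way `maxScoreIdx` is used with `ImageNetClasses.IMAGENET_CLASSES` in the HelloWorld example.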

## Building pytorch android from source
## Building PyTorch Android from Source

In some cases you might want to use a local build of pytorch android, for example you may build custom libtorch binary with another set of operators or to make local changes.

@@ -175,19 +177,22 @@ cd pytorch
sh ./scripts/build_pytorch_android.sh
```

Its workflow contains several steps:
1. Builds libtorch for android for all 4 android abis (armeabi-v7a, arm64-v8a, x86, x86_64)
2. Creates symbolic links to the results of those builds:
The workflow contains several steps:

1\. Build libtorch for android for all 4 android abis (armeabi-v7a, arm64-v8a, x86, x86_64)

2\. Create symbolic links to the results of those builds:
`android/pytorch_android/src/main/jniLibs/${abi}` to the directory with output libraries
`android/pytorch_android/src/main/cpp/libtorch_include/${abi}` to the directory with headers. These directories are used to build `libpytorch.so` library that will be loaded on android device.
3. And finally runs `gradle` in `android/pytorch_android` directory with task `assembleRelease`
`android/pytorch_android/src/main/cpp/libtorch_include/${abi}` to the directory with headers. These directories are used to build `libpytorch.so` library that will be loaded on android device.

3\. And finally run `gradle` in `android/pytorch_android` directory with task `assembleRelease`

Script requires that android sdk, android ndk and gradle are installed.
The script requires that the Android SDK, Android NDK and Gradle are installed.
They are specified as environment variables:

`ANDROID_HOME` - path to [android sdk](https://developer.android.com/studio/command-line/sdkmanager.html)
`ANDROID_HOME` - path to [Android SDK](https://developer.android.com/studio/command-line/sdkmanager.html)

`ANDROID_NDK` - path to [android ndk](https://developer.android.com/studio/projects/install-ndk)
`ANDROID_NDK` - path to [Android NDK](https://developer.android.com/studio/projects/install-ndk)

`GRADLE_HOME` - path to [gradle](https://gradle.org/releases/)
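
Before starting a long build it can help to confirm that the three variables are set. The following is a small sketch under that assumption; `BuildEnvCheck` and `missing` are hypothetical names and are not part of the build script, which may perform its own checks.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.Map;

// Hypothetical pre-flight check for the variables the build script reads.
final class BuildEnvCheck {
    static final String[] REQUIRED = {"ANDROID_HOME", "ANDROID_NDK", "GRADLE_HOME"};

    /** Returns the names of required variables that are absent or blank in {@code env}. */
    static List<String> missing(Map<String, String> env) {
        List<String> result = new ArrayList<>();
        for (String name : REQUIRED) {
            String value = env.get(name);
            if (value == null || value.trim().isEmpty()) {
                result.add(name);
            }
        }
        return result;
    }

    public static void main(String[] args) {
        List<String> unset = missing(System.getenv());
        System.out.println(unset.isEmpty() ? "All build variables are set" : "Missing: " + unset);
    }
}
```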

@@ -226,7 +231,7 @@ dependencies {
}
```

At the moment for the case of using aar files directly we need additional configuration due to packaging specific (libfbjni.so is packaged in both pytorch_android_fbjni.aar and pytorch_android.aar).
At the moment, when using aar files directly, we need additional configuration due to packaging specifics (`libfbjni.so` is packaged in both `pytorch_android_fbjni.aar` and `pytorch_android.aar`).
```
packagingOptions {
pickFirst "**/libfbjni.so"
@@ -235,52 +240,49 @@ packagingOptions {

## API Details

Main part of java api includes 3 classes:
Main part of java API includes 3 classes:
```
org.pytorch.Module
org.pytorch.IValue
org.pytorch.Tensor
```

If the reader is familiar with pytorch python api, we can think that org.pytorch.Tensor represents torch.tensor, org.pytorch.Module torch.Module<?>, while org.pytorch.IValue represents value of TorchScript variable, supporting all its types. ( https://pytorch.org/docs/stable/jit.html#types )
If the reader is familiar with the PyTorch Python API, `org.pytorch.Tensor` can be thought of as representing `torch.Tensor`, `org.pytorch.Module` as representing `torch.nn.Module`, and `org.pytorch.IValue` as representing the value of a TorchScript variable, supporting all its [types](https://pytorch.org/docs/stable/jit.html#types).

### org.pytorch.Tensor (Tensor)
[github](https://github.com/pytorch/pytorch/blob/master/android/pytorch_android/src/main/java/org/pytorch/Tensor.java)
### [`org.pytorch.Tensor`](https://github.com/pytorch/pytorch/blob/master/android/pytorch_android/src/main/java/org/pytorch/Tensor.java)

Tensor supports dtypes `uint8, int8, float32, int32, float64, int64`.
Tensor holds data in DirectByteBuffer of proper type with native bit order.

To create a Tensor user can use one of the factory methods:
```
Tensor newUInt8Tensor(long[] shape, ByteBuffer data)
Tensor newUInt8Tensor(long[] shape, byte[] data)

Tensor newInt8Tensor(long[] shape, ByteBuffer data)
Tensor newInt8Tensor(long[] shape, byte[] data)

Tensor newFloat32Tensor(long[] shape, FloatBuffer data)
Tensor newFloat32Tensor(long[] shape, float[] data)
Tensor fromBlobUnsigned(ByteBuffer data, long[] shape)
Tensor fromBlobUnsigned(byte[] data, long[] shape)

Tensor fromBlob(ByteBuffer data, long[] shape)
Tensor fromBlob(byte[] data, long[] shape)

Tensor newInt32Tensor(long[] shape, IntBuffer data)
Tensor newInt32Tensor(long[] shape, int[] data)
Tensor fromBlob(FloatBuffer data, long[] shape)
Tensor fromBlob(float[] data, long[] shape)

Tensor newFloat64Tensor(long[] shape, DoubleBuffer data)
Tensor newFloat64Tensor(long[] shape, double[] data)
Tensor fromBlob(IntBuffer data, long[] shape)
Tensor fromBlob(int[] data, long[] shape)

Tensor fromBlob(DoubleBuffer data, long[] shape)
Tensor fromBlob(double[] data, long[] shape)

Tensor newInt64Tensor(long[] shape, LongBuffer data)
Tensor newInt64Tensor(long[] shape, long[] data)
Tensor fromBlob(LongBuffer data, long[] shape)
Tensor fromBlob(long[] data, long[] shape)
```
Where the `long[] shape` parameter is the shape of the Tensor as an array of longs.

Content of the Tensor can be provided either as (a) java array or (b) as java.nio.DirectByteBuffer of proper type with native bit order.
Content of the Tensor can be provided either as (a) java array or (b) as `java.nio.DirectByteBuffer` of proper type with native bit order.

In case of (a) proper DirectByteBuffer will be created internally. (b) case has an advantage that user can keep the reference to DirectByteBuffer and change its content in future for the next run, avoiding allocation of DirectByteBuffer for repeated runs.
In case of (a) proper `DirectByteBuffer` will be created internally. (b) case has an advantage that user can keep the reference to DirectByteBuffer and change its content in future for the next run, avoiding allocation of DirectByteBuffer for repeated runs.
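
Case (b) can be sketched with plain `java.nio`. The example below allocates a direct buffer with native byte order once and refills it for each run; `DirectBufferSketch`, `allocateFloatBuffer` and `fill` are hypothetical names used only for illustration. The resulting `FloatBuffer` is the kind of object that would be passed to one of the factory methods listed above.

```java
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.nio.FloatBuffer;

// Hypothetical illustration of reusing one direct buffer across runs.
final class DirectBufferSketch {

    /** Allocates a direct, native-order buffer with room for numElements floats. */
    static FloatBuffer allocateFloatBuffer(int numElements) {
        return ByteBuffer.allocateDirect(numElements * Float.BYTES)
                .order(ByteOrder.nativeOrder())
                .asFloatBuffer();
    }

    /** Overwrites the buffer content in place, without allocating a new buffer. */
    static void fill(FloatBuffer buffer, float[] values) {
        buffer.clear();      // position = 0, limit = capacity
        buffer.put(values);  // throws BufferOverflowException if values do not fit
        buffer.rewind();     // position = 0 again, ready to be read
    }

    public static void main(String[] args) {
        FloatBuffer input = allocateFloatBuffer(3);
        fill(input, new float[] {1f, 2f, 3f});   // first run
        fill(input, new float[] {4f, 5f, 6f});   // next run reuses the same memory
        System.out.println(input.isDirect() + " " + input.get(0));
    }
}
```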

Java’s primitive type byte is signed and java does not have unsigned 8 bit type. For dtype=uint8 api uses byte that will be reinterpretted as uint8 on native side. On java side unsigned value of byte can be read as (byte & 0xFF).
Java’s primitive type byte is signed and java does not have unsigned 8 bit type. For dtype=uint8 java API uses java primitive `byte` that will be reinterpreted as uint8 on native side. On java side unsigned value of byte can be read as `byte & 0xFF`.
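
A short illustration of the `byte & 0xFF` idiom follows; `UnsignedByteSketch` and `toUnsigned` are hypothetical names. Masking widens the signed byte to an `int` and keeps only the low 8 bits, which yields the value in the range 0 to 255.

```java
// Hypothetical illustration of reading Java bytes as unsigned values.
final class UnsignedByteSketch {

    /** Interprets a signed Java byte as an unsigned value in [0, 255]. */
    static int toUnsigned(byte value) {
        return value & 0xFF;
    }

    /** Converts a whole array, for example the content of a uint8 tensor. */
    static int[] toUnsigned(byte[] values) {
        int[] result = new int[values.length];
        for (int i = 0; i < values.length; i++) {
            result[i] = toUnsigned(values[i]);
        }
        return result;
    }

    public static void main(String[] args) {
        byte raw = (byte) 200;  // stored as a negative number in a signed byte
        System.out.println(raw + " -> " + toUnsigned(raw));
    }
}
```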

#### Tensor content layout
#### Tensor Content Layout

Tensor content is represented as a one dimensional array (buffer),
where the first element has all zero indexes T\[0, ... 0\].
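
To make the flat layout concrete, the sketch below maps a multi-dimensional index to a position in the one dimensional buffer. It assumes a contiguous row-major layout, where the right-most index changes fastest; `TensorLayoutSketch` and `flatIndex` are hypothetical names and not part of the API.

```java
// Hypothetical illustration, assuming a contiguous row-major layout.
final class TensorLayoutSketch {

    /** Returns the buffer position of element T[index[0], ..., index[n-1]]. */
    static int flatIndex(long[] shape, long[] index) {
        if (shape.length != index.length) {
            throw new IllegalArgumentException("shape and index must have the same rank");
        }
        long offset = 0;
        for (int d = 0; d < shape.length; d++) {
            if (index[d] < 0 || index[d] >= shape[d]) {
                throw new IndexOutOfBoundsException("index " + index[d] + " out of range in dimension " + d);
            }
            // Horner-style accumulation: earlier dimensions have larger strides
            offset = offset * shape[d] + index[d];
        }
        return Math.toIntExact(offset);
    }

    public static void main(String[] args) {
        long[] shape = {1, 3, 2, 2};
        System.out.println(flatIndex(shape, new long[] {0, 0, 0, 0}));
        System.out.println(flatIndex(shape, new long[] {0, 1, 0, 1}));
    }
}
```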
@@ -293,17 +295,17 @@ Tensor has methods to check its dtype:
```
int dtype()
```
That returns one of the dtype codes:
That returns one of the `DType` enum elements:
```
Tensor.DTYPE_UINT8
Tensor.DTYPE_INT8
Tensor.DTYPE_INT32
Tensor.DTYPE_FLOAT32
Tensor.DTYPE_INT64
Tensor.DTYPE_FLOAT64
DType.UINT8
DType.INT8
DType.INT32
DType.FLOAT32
DType.INT64
DType.FLOAT64
```

The data of Tensor can be read as java array:
The data of Tensor can be read as a java array:
```
byte[] getDataAsUnsignedByteArray()
byte[] getDataAsByteArray()
@@ -312,18 +314,17 @@ long[] getDataAsLongArray()
float[] getDataAsFloatArray()
double[] getDataAsDoubleArray()
```
These methods throw IllegalStateException if called for inappropriate dtype.
These methods throw `IllegalStateException` if called for inappropriate dtype.

### org.pytorch.IValue (IValue)
[github](https://github.com/pytorch/pytorch/blob/master/android/pytorch_android/src/main/java/org/pytorch/IValue.java)
### [`org.pytorch.IValue`](https://github.com/pytorch/pytorch/blob/master/android/pytorch_android/src/main/java/org/pytorch/IValue.java)

IValue represents a TorchScript variable that can be one of the supported (by torchscript) types ( https://pytorch.org/docs/stable/jit.html#types ). IValue is a tagged union. For every supported type it has a factory method, method to check the type and a getter method to retrieve a value.
Getters throw IllegalStateException if called for inappropriate type.
IValue represents a TorchScript variable that can be one of the supported (by TorchScript) [types](https://pytorch.org/docs/stable/jit.html#types).
`IValue` is a tagged union. For every supported type it has a factory method, method to check the type and a getter method to retrieve a value.
Getters throw `IllegalStateException` if called for inappropriate type.
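
The tagged-union pattern described here can be sketched in a few lines. `TaggedValue` below is a hypothetical stand-in, not the real `IValue`: it supports only three types and exists to show the combination of a factory method, a type check and a getter that throws `IllegalStateException` on a type mismatch.

```java
// Hypothetical stand-in for illustration; the real IValue supports more types.
final class TaggedValue {
    private enum Tag { LONG, DOUBLE, STRING }

    private final Tag tag;
    private final Object payload;

    private TaggedValue(Tag tag, Object payload) {
        this.tag = tag;
        this.payload = payload;
    }

    // Factory methods: one per supported type
    static TaggedValue from(long value) { return new TaggedValue(Tag.LONG, value); }
    static TaggedValue from(double value) { return new TaggedValue(Tag.DOUBLE, value); }
    static TaggedValue from(String value) { return new TaggedValue(Tag.STRING, value); }

    // Type checks
    boolean isLong() { return tag == Tag.LONG; }
    boolean isDouble() { return tag == Tag.DOUBLE; }
    boolean isString() { return tag == Tag.STRING; }

    // Getters: each verifies the tag before casting the payload
    long toLong() { require(Tag.LONG); return (Long) payload; }
    double toDouble() { require(Tag.DOUBLE); return (Double) payload; }
    String toStr() { require(Tag.STRING); return (String) payload; }

    private void require(Tag expected) {
        if (tag != expected) {
            throw new IllegalStateException("Expected " + expected + " but value holds " + tag);
        }
    }

    public static void main(String[] args) {
        TaggedValue v = TaggedValue.from(42L);
        System.out.println(v.isLong() + " " + v.toLong());
    }
}
```

Checking the type before calling a getter, for example `if (v.isLong()) { ... }`, avoids the exception entirely.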

### org.pytorch.Module (Module)
[github](https://github.com/pytorch/pytorch/blob/master/android/pytorch_android/src/main/java/org/pytorch/Module.java)
### [`org.pytorch.Module`](https://github.com/pytorch/pytorch/blob/master/android/pytorch_android/src/main/java/org/pytorch/Module.java)

Module is a wrapper of torch.jit.ScriptModule (`torch::jit::script::Module` in pytorch c++ api) which can be constructed with factory method load providing absolute path to the file with serialized TorchScript.
Module is a wrapper of torch.jit.ScriptModule (`torch::jit::script::Module` in PyTorch C++ API) which can be constructed with the factory method `load` providing absolute path to the file with serialized TorchScript.
```
IValue IValue.runMethod(String methodName, IValue... inputs)
```