DJL

 
DJL - Deep Java Library

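An open-source deep learning toolkit that Amazon announced in 2019.
It is built on top of existing deep learning frameworks using native Java concepts.
It supports MXNet, TensorFlow and PyTorch.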

http://docs.djl.ai/engines/pytorch/pytorch-engine/index.html

DJL common dependency

 
Set the environment variable PYTORCH_VERSION to override the default package version.

<dependency>
    <groupId>ai.djl.pytorch</groupId>
    <artifactId>pytorch-engine</artifactId>
    <version>0.22.1</version>
    <scope>runtime</scope>
</dependency>

DJL Supported PyTorch versions

 
Since DJL 0.14.0, pytorch-engine can load older versions of the PyTorch native library.

PyTorch engine version 	PyTorch native library version
pytorch-engine:0.22.1 	1.11.0, 1.12.1, 1.13.1, 2.0.0
pytorch-engine:0.21.0 	1.11.0, 1.12.1, 1.13.1
pytorch-engine:0.20.0 	1.11.0, 1.12.1, 1.13.0
pytorch-engine:0.19.0 	1.10.0, 1.11.0, 1.12.1
pytorch-engine:0.18.0 	1.9.1, 1.10.0, 1.11.0
pytorch-engine:0.17.0 	1.9.1, 1.10.0, 1.11.0
pytorch-engine:0.16.0 	1.8.1, 1.9.1, 1.10.0

A newer pytorch-engine also supports older PyTorch versions, i.e. it is backward compatible,
so use the latest pytorch-engine whenever possible.
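The compatibility table can also be checked in code. The sketch below is a hypothetical helper (PytorchCompat and isSupported are illustrative names, not part of DJL) that encodes the rows above and reports whether the version requested through PYTORCH_VERSION is listed for pytorch-engine 0.22.1. Falling back to 2.0.0 when the variable is unset is an assumption based on the native dependencies used in this article.

```java
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

/** Hypothetical helper: the pytorch-engine compatibility table encoded as data. */
public class PytorchCompat {

    // pytorch-engine version -> PyTorch native library versions it can load (copied from the table)
    private static final Map<String, List<String>> SUPPORTED = new HashMap<>();

    static {
        SUPPORTED.put("0.22.1", Arrays.asList("1.11.0", "1.12.1", "1.13.1", "2.0.0"));
        SUPPORTED.put("0.21.0", Arrays.asList("1.11.0", "1.12.1", "1.13.1"));
        SUPPORTED.put("0.20.0", Arrays.asList("1.11.0", "1.12.1", "1.13.0"));
        SUPPORTED.put("0.19.0", Arrays.asList("1.10.0", "1.11.0", "1.12.1"));
        SUPPORTED.put("0.18.0", Arrays.asList("1.9.1", "1.10.0", "1.11.0"));
        SUPPORTED.put("0.17.0", Arrays.asList("1.9.1", "1.10.0", "1.11.0"));
        SUPPORTED.put("0.16.0", Arrays.asList("1.8.1", "1.9.1", "1.10.0"));
    }

    /** Returns true if the table lists nativeVersion for the given pytorch-engine version. */
    public static boolean isSupported(String engineVersion, String nativeVersion) {
        return SUPPORTED.getOrDefault(engineVersion, Collections.emptyList()).contains(nativeVersion);
    }

    public static void main(String[] args) {
        String engine = "0.22.1";
        // Assumption for this sketch: use 2.0.0 when PYTORCH_VERSION is not set
        String requested = System.getenv().getOrDefault("PYTORCH_VERSION", "2.0.0");
        System.out.println("pytorch-engine " + engine + " with PyTorch " + requested + ": "
                + (isSupported(engine, requested) ? "listed in the table" : "not listed in the table"));
    }
}
```

Keeping the table as data makes it easy to extend when a new pytorch-engine release adds rows.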

Windows CPU

 
<dependency>
    <groupId>ai.djl.pytorch</groupId>
    <artifactId>pytorch-native-cpu</artifactId>
    <classifier>win-x86_64</classifier>
    <scope>runtime</scope>
    <version>2.0.0</version>
</dependency>
<dependency>
    <groupId>ai.djl.pytorch</groupId>
    <artifactId>pytorch-jni</artifactId>
    <version>2.0.0-0.22.1</version>
    <scope>runtime</scope>
</dependency>

Windows GPU

 
<dependency>
    <groupId>ai.djl.pytorch</groupId>
    <artifactId>pytorch-native-cu118</artifactId>
    <classifier>win-x86_64</classifier>
    <version>2.0.0</version>
    <scope>runtime</scope>
</dependency>
<dependency>
    <groupId>ai.djl.pytorch</groupId>
    <artifactId>pytorch-jni</artifactId>
    <version>2.0.0-0.22.1</version>
    <scope>runtime</scope>
</dependency>

Linux CPU

 
<dependency>
    <groupId>ai.djl.pytorch</groupId>
    <artifactId>pytorch-native-cpu</artifactId>
    <classifier>linux-x86_64</classifier>
    <scope>runtime</scope>
    <version>2.0.0</version>
</dependency>
<dependency>
    <groupId>ai.djl.pytorch</groupId>
    <artifactId>pytorch-jni</artifactId>
    <version>2.0.0-0.22.1</version>
    <scope>runtime</scope>
</dependency>

Linux GPU

 
<dependency>
    <groupId>ai.djl.pytorch</groupId>
    <artifactId>pytorch-native-cu118</artifactId>
    <classifier>linux-x86_64</classifier>
    <version>2.0.0</version>
    <scope>runtime</scope>
</dependency>
<dependency>
    <groupId>ai.djl.pytorch</groupId>
    <artifactId>pytorch-jni</artifactId>
    <version>2.0.0-0.22.1</version>
    <scope>runtime</scope>
</dependency>

macOS M1

 
<dependency>
    <groupId>ai.djl.pytorch</groupId>
    <artifactId>pytorch-native-cpu</artifactId>
    <classifier>osx-aarch64</classifier>
    <version>2.0.0</version>
    <scope>runtime</scope>
</dependency>
<dependency>
    <groupId>ai.djl.pytorch</groupId>
    <artifactId>pytorch-jni</artifactId>
    <version>2.0.0-0.22.1</version>
    <scope>runtime</scope>
</dependency>
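Because the project in this article is developed on Windows and deployed on Linux, the pytorch-native classifier has to change with the platform. A minimal sketch, assuming the JVM's standard os.name and os.arch system properties are enough to tell the platforms apart, that maps them to the three classifiers used above; NativeClassifier and classifierFor are hypothetical names.

```java
import java.util.Locale;

/** Hypothetical helper mirroring the classifiers used in the dependency sections above. */
public class NativeClassifier {

    /**
     * Maps JVM os.name / os.arch values to the Maven classifier of pytorch-native-*.
     * Only the three platforms listed above are covered; anything else throws.
     */
    public static String classifierFor(String osName, String osArch) {
        String os = osName.toLowerCase(Locale.ROOT);
        String arch = osArch.toLowerCase(Locale.ROOT);
        // 64-bit x86 is typically reported as "amd64" on Windows/Linux and "x86_64" on macOS
        boolean x86_64 = arch.equals("amd64") || arch.equals("x86_64");
        boolean aarch64 = arch.equals("aarch64") || arch.equals("arm64");

        if (os.startsWith("windows") && x86_64) {
            return "win-x86_64";
        }
        if (os.startsWith("linux") && x86_64) {
            return "linux-x86_64";
        }
        if (os.startsWith("mac") && aarch64) {
            return "osx-aarch64";
        }
        throw new IllegalArgumentException("No classifier listed for " + osName + " / " + osArch);
    }

    public static void main(String[] args) {
        String os = System.getProperty("os.name");
        String arch = System.getProperty("os.arch");
        try {
            System.out.println("Use classifier: " + classifierFor(os, arch));
        } catch (IllegalArgumentException e) {
            System.out.println(e.getMessage());
        }
    }
}
```

Running it once on each machine (the Windows development box and the CentOS 7 server) shows which classifier that machine's pom needs.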

pytorch-model-zoo

Pre-trained models

 
The PyTorch model zoo contains Computer Vision (CV) models.
All the models are grouped by task under the following category:
CV
    Image Classification
    Object Detection
    Style Transfer
    Image Generation

DJL Model Zoo

 
<dependency>
    <groupId>ai.djl.pytorch</groupId>
    <artifactId>pytorch-model-zoo</artifactId>
    <version>0.22.1</version>
</dependency>

CV example 1

pom

 
<project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
    xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
    <modelVersion>4.0.0</modelVersion>
  
    <groupId>org.example</groupId>
    <artifactId>w11</artifactId>
    <version>1.0-SNAPSHOT</version>
    <packaging>jar</packaging>
  
    <name>w11</name>
    <url>http://maven.apache.org</url>
  
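    <properties>
      <project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
      <maven.compiler.source>1.8</maven.compiler.source>
      <maven.compiler.target>1.8</maven.compiler.target>
    </properties>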
  
    <dependencies>
      <dependency>
        <groupId>ai.djl</groupId>
        <artifactId>api</artifactId>
        <version>0.22.1</version>
      </dependency>
      <dependency>
        <groupId>ai.djl.pytorch</groupId>
        <artifactId>pytorch-engine</artifactId>
        <version>0.22.1</version>
        <scope>runtime</scope>
      </dependency>
      <dependency>
        <groupId>ai.djl.pytorch</groupId>
        <artifactId>pytorch-native-cpu</artifactId>
        <classifier>win-x86_64</classifier>
        <scope>runtime</scope>
        <version>2.0.0</version>
      </dependency>
      <dependency>
        <groupId>ai.djl.pytorch</groupId>
        <artifactId>pytorch-jni</artifactId>
        <version>2.0.0-0.22.1</version>
        <scope>runtime</scope>
      </dependency>
      <dependency>
        <groupId>org.slf4j</groupId>
        <artifactId>slf4j-nop</artifactId>
        <version>1.7.2</version>
        <type>jar</type>
      </dependency>
      <dependency>
        <groupId>junit</groupId>
        <artifactId>junit</artifactId>
        <version>3.8.1</version>
        <scope>test</scope>
      </dependency>
    </dependencies>
  </project>

code

 
package org.example;

import java.nio.file.Paths;

import ai.djl.inference.Predictor;
import ai.djl.modality.Classifications;
import ai.djl.modality.cv.Image;
import ai.djl.modality.cv.ImageFactory;
import ai.djl.modality.cv.transform.CenterCrop;
import ai.djl.modality.cv.transform.Normalize;
import ai.djl.modality.cv.transform.Resize;
import ai.djl.modality.cv.transform.ToTensor;
import ai.djl.modality.cv.translator.ImageClassificationTranslator;
import ai.djl.repository.zoo.Criteria;
import ai.djl.repository.zoo.ZooModel;
import ai.djl.training.util.DownloadUtils;
import ai.djl.training.util.ProgressBar;
import ai.djl.translate.Translator;

public class App 
{
    public static void main( String[] args )
    {
        try {
            // Download the traced ResNet-18 model and its ImageNet label file (synset.txt)
            DownloadUtils.download("https://djl-ai.s3.amazonaws.com/mlrepo/model/cv/image_classification/ai/djl/pytorch/resnet/0.0.1/traced_resnet18.pt.gz", "build/pytorch_models/resnet18/resnet18.pt", new ProgressBar());
            DownloadUtils.download("https://djl-ai.s3.amazonaws.com/mlrepo/model/cv/image_classification/ai/djl/pytorch/synset.txt", "build/pytorch_models/resnet18/synset.txt", new ProgressBar());
            /*
             * Equivalent torchvision preprocessing in Python:
             * preprocess = transforms.Compose([
             *     transforms.Resize(256),
             *     transforms.CenterCrop(224),
             *     transforms.ToTensor(),
             *     transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
             * ])
             * Note: DJL's Resize(256) scales the image to 256x256, while torchvision's
             * Resize(256) scales the shorter side to 256 and keeps the aspect ratio.
             */
            Translator<Image, Classifications> translator = ImageClassificationTranslator.builder()
                    .addTransform(new Resize(256))
                    .addTransform(new CenterCrop(224, 224))
                    .addTransform(new ToTensor())
                    .addTransform(new Normalize(
                            new float[] {0.485f, 0.456f, 0.406f},
                            new float[] {0.229f, 0.224f, 0.225f}))
                    .optApplySoftmax(true)
                    .build();

            Criteria<Image, Classifications> criteria = Criteria.builder()
                    .setTypes(Image.class, Classifications.class)
                    .optModelPath(Paths.get("build/pytorch_models/resnet18"))
                    .optOption("mapLocation", "true") // this model requires mapLocation for GPU
                    .optTranslator(translator)
                    .optProgress(new ProgressBar())
                    .build();

            // ZooModel and Predictor hold native resources, so close them with try-with-resources
            try (ZooModel<Image, Classifications> model = criteria.loadModel();
                 Predictor<Image, Classifications> predictor = model.newPredictor()) {
                // Image img = ImageFactory.getInstance().fromUrl("https://resources.djl.ai/images/0.png");
                // 0.png is a handwritten digit "0" image from a digit recognition dataset, saved here beforehand
                Image img = ImageFactory.getInstance().fromFile(Paths.get("build/pytorch_models/resnet18/0.png"));

                Classifications classifications = predictor.predict(img);
                System.out.println(classifications);
            }
        } catch (Exception e) {
            // Report failures (missing native library, model file or image) instead of swallowing them
            e.printStackTrace();
        }
    }
}
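To see what the translator does around the model, here is a plain-Java sketch of the arithmetic only; it does not call DJL. ToTensor scales a 0-255 pixel value to [0, 1], Normalize then applies (x - mean) / std per channel with the mean and std used above, and optApplySoftmax(true) turns the raw outputs into probabilities whose largest entry indexes the matching label in synset.txt. PrePostSketch and its methods are hypothetical names.

```java
import java.util.Arrays;

/** Hypothetical sketch of the pre/post-processing arithmetic; not DJL code. */
public class PrePostSketch {

    public static final float[] MEAN = {0.485f, 0.456f, 0.406f};
    public static final float[] STD = {0.229f, 0.224f, 0.225f};

    /** ToTensor: 0-255 -> [0, 1]; Normalize: (x - mean) / std for the given channel. */
    public static float normalize(int pixel, int channel) {
        float x = pixel / 255f;
        return (x - MEAN[channel]) / STD[channel];
    }

    /** Softmax: raw logits -> probabilities that sum to 1. */
    public static double[] softmax(double[] logits) {
        double max = Arrays.stream(logits).max().orElse(0); // subtract the max for numerical stability
        double[] probs = new double[logits.length];
        double sum = 0;
        for (int i = 0; i < logits.length; i++) {
            probs[i] = Math.exp(logits[i] - max);
            sum += probs[i];
        }
        for (int i = 0; i < probs.length; i++) {
            probs[i] /= sum;
        }
        return probs;
    }

    /** Index of the highest probability, i.e. the position of the top-1 label in synset.txt. */
    public static int argmax(double[] probs) {
        int best = 0;
        for (int i = 1; i < probs.length; i++) {
            if (probs[i] > probs[best]) {
                best = i;
            }
        }
        return best;
    }

    public static void main(String[] args) {
        // A pure white pixel (255, 255, 255) after normalization, channel by channel
        for (int c = 0; c < 3; c++) {
            System.out.printf("channel %d: %.4f%n", c, normalize(255, c));
        }
        double[] probs = softmax(new double[] {1.0, 3.0, 0.5});
        System.out.println(Arrays.toString(probs) + " -> top-1 index " + argmax(probs));
    }
}
```

If the Java results differ noticeably from the Python ones, comparing these per-channel values against the torchvision pipeline is a quick first check.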
    
Calling a PyTorch model from DJL

Environment

 
PyTorch: 1.10.2
DJL: 0.22.1
OS: developed on Windows first, then deployed to and run on CentOS 7

Creating the project on Linux

 
mvn archetype:generate -DgroupId=org.test -DartifactId=lnx1 -DarchetypeArtifactId=maven-archetype-quickstart -DinteractiveMode=false

Generating the PyTorch model in Python

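import torch
import torchvision

# resnet50() without pretrained weights is randomly initialized; enough to demonstrate the export
model = torchvision.models.resnet50()
model.eval()  # switch to inference mode before tracing

# Dummy input with the shape the model expects: batch 1, 3 channels, 224x224
example = torch.rand(1, 3, 224, 224)
# Record the operations run on the example input as a TorchScript module
traced_script_module = torch.jit.trace(model, example)
traced_script_module.save("traced_model_resnet50.pt")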

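Note: do not save the model with torch.save; export it with torch.jit.trace and save the traced module instead. DJL's PyTorch engine loads TorchScript models, not regular pickled PyTorch models.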
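A quick pre-flight check can catch a model saved with torch.save before DJL tries to load it. The sketch below is a heuristic, not a format validator: it assumes that a TorchScript .pt file is a zip archive containing a code/ directory (the serialized TorchScript source) and that a torch.save file lacks one. TorchScriptCheck and looksLikeTorchScript are hypothetical names.

```java
import java.io.IOException;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.Enumeration;
import java.util.zip.ZipEntry;
import java.util.zip.ZipException;
import java.util.zip.ZipFile;

/** Hypothetical heuristic check for TorchScript archives; not part of DJL or PyTorch. */
public class TorchScriptCheck {

    /** Returns true if the .pt file is a zip archive with a code/ directory. */
    public static boolean looksLikeTorchScript(Path file) throws IOException {
        try (ZipFile zip = new ZipFile(file.toFile())) {
            Enumeration<? extends ZipEntry> entries = zip.entries();
            while (entries.hasMoreElements()) {
                String name = entries.nextElement().getName();
                // Entries are usually prefixed with an archive name, e.g. "traced_model_resnet50/code/..."
                if (name.startsWith("code/") || name.contains("/code/")) {
                    return true;
                }
            }
            return false;
        } catch (ZipException e) {
            // Not a zip archive at all: treat it as not TorchScript
            return false;
        }
    }

    public static void main(String[] args) throws IOException {
        Path model = Paths.get(args.length > 0 ? args[0] : "traced_model_resnet50.pt");
        System.out.println(model + (looksLikeTorchScript(model)
                ? " looks like a TorchScript archive"
                : " does not look like a TorchScript archive"));
    }
}
```

Run it as java TorchScriptCheck traced_model_resnet50.pt before copying the model to the CentOS 7 server.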

References