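DJL (Deep Java Library) is an open-source deep learning toolkit that Amazon announced in 2019. It is built on top of existing deep learning frameworks using native Java concepts and supports MXNet, TensorFlow and PyTorch. PyTorch engine documentation: http://docs.djl.ai/engines/pytorch/pytorch-engine/index.html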
Common DJL dependency
Set the environment variable PYTORCH_VERSION to override the default PyTorch native library version.

<dependency>
    <groupId>ai.djl.pytorch</groupId>
    <artifactId>pytorch-engine</artifactId>
    <version>0.22.1</version>
    <scope>runtime</scope>
</dependency>
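You set the override before the JVM starts. On Linux that would be, for example, export PYTORCH_VERSION=1.13.1, and in a Windows command prompt set PYTORCH_VERSION=1.13.1. The precedence this implies can be sketched in plain Java. PytorchVersionOverride is a hypothetical helper written for illustration and is not DJL code. The 2.0.0 fallback is an assumption: it is the newest native version listed for pytorch-engine 0.22.1.

```java
import java.util.Map;

/** Hypothetical helper (not DJL code): illustrates how PYTORCH_VERSION overrides a default. */
public final class PytorchVersionOverride {
    // Assumed default: the newest native version listed for pytorch-engine 0.22.1
    static final String DEFAULT_VERSION = "2.0.0";

    private PytorchVersionOverride() {}

    /** Returns PYTORCH_VERSION from env if it is set and non-blank, otherwise defaultVersion. */
    public static String resolve(Map<String, String> env, String defaultVersion) {
        String override = env.get("PYTORCH_VERSION");
        if (override == null || override.trim().isEmpty()) {
            return defaultVersion;
        }
        return override.trim();
    }

    public static void main(String[] args) {
        // Reads the real environment of this JVM process
        System.out.println("PyTorch native version: " + resolve(System.getenv(), DEFAULT_VERSION));
    }
}
```

The method takes the environment as a Map so it can be exercised without changing the real process environment. Whatever value you set should be one the engine supports, according to the compatibility table.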
PyTorch versions supported by DJL
Since DJL 0.14.0, pytorch-engine can load older versions of the PyTorch native library.

PyTorch engine version    PyTorch native library versions
pytorch-engine:0.22.1     1.11.0, 1.12.1, 1.13.1, 2.0.0
pytorch-engine:0.21.0     1.11.0, 1.12.1, 1.13.1
pytorch-engine:0.20.0     1.11.0, 1.12.1, 1.13.0
pytorch-engine:0.19.0     1.10.0, 1.11.0, 1.12.1
pytorch-engine:0.18.0     1.9.1, 1.10.0, 1.11.0
pytorch-engine:0.17.0     1.9.1, 1.10.0, 1.11.0
pytorch-engine:0.16.0     1.8.1, 1.9.1, 1.10.0

A newer pytorch-engine also supports older PyTorch models, i.e. it is backward compatible, so use the latest pytorch-engine where possible.
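If you want to check an engine/native combination from code, you can keep the table as a lookup. This is a minimal sketch that assumes you copy the data from the table by hand. PytorchCompat is a hypothetical class and is not part of DJL.

```java
import java.util.Arrays;
import java.util.Collections;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;

/** Hypothetical lookup built from the compatibility table (data copied by hand, not read from DJL). */
public final class PytorchCompat {
    private static final Map<String, List<String>> SUPPORTED = new LinkedHashMap<>();

    static {
        SUPPORTED.put("0.22.1", Arrays.asList("1.11.0", "1.12.1", "1.13.1", "2.0.0"));
        SUPPORTED.put("0.21.0", Arrays.asList("1.11.0", "1.12.1", "1.13.1"));
        SUPPORTED.put("0.20.0", Arrays.asList("1.11.0", "1.12.1", "1.13.0"));
        SUPPORTED.put("0.19.0", Arrays.asList("1.10.0", "1.11.0", "1.12.1"));
        SUPPORTED.put("0.18.0", Arrays.asList("1.9.1", "1.10.0", "1.11.0"));
        SUPPORTED.put("0.17.0", Arrays.asList("1.9.1", "1.10.0", "1.11.0"));
        SUPPORTED.put("0.16.0", Arrays.asList("1.8.1", "1.9.1", "1.10.0"));
    }

    private PytorchCompat() {}

    /** Native versions listed for an engine version; empty if the engine is not in the table. */
    public static List<String> nativeVersionsFor(String engineVersion) {
        return SUPPORTED.getOrDefault(engineVersion, Collections.emptyList());
    }

    /** True if the table lists nativeVersion for engineVersion. */
    public static boolean isSupported(String engineVersion, String nativeVersion) {
        return nativeVersionsFor(engineVersion).contains(nativeVersion);
    }

    public static void main(String[] args) {
        System.out.println(isSupported("0.22.1", "2.0.0")); // true
        System.out.println(isSupported("0.21.0", "2.0.0")); // false
    }
}
```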
Windows CPU
<dependency>
    <groupId>ai.djl.pytorch</groupId>
    <artifactId>pytorch-native-cpu</artifactId>
    <classifier>win-x86_64</classifier>
    <scope>runtime</scope>
    <version>2.0.0</version>
</dependency>
<dependency>
    <groupId>ai.djl.pytorch</groupId>
    <artifactId>pytorch-jni</artifactId>
    <version>2.0.0-0.22.1</version>
    <scope>runtime</scope>
</dependency>
Windows GPU
<dependency>
    <groupId>ai.djl.pytorch</groupId>
    <artifactId>pytorch-native-cu118</artifactId>
    <classifier>win-x86_64</classifier>
    <version>2.0.0</version>
    <scope>runtime</scope>
</dependency>
<dependency>
    <groupId>ai.djl.pytorch</groupId>
    <artifactId>pytorch-jni</artifactId>
    <version>2.0.0-0.22.1</version>
    <scope>runtime</scope>
</dependency>
Linux CPU
<dependency>
    <groupId>ai.djl.pytorch</groupId>
    <artifactId>pytorch-native-cpu</artifactId>
    <classifier>linux-x86_64</classifier>
    <scope>runtime</scope>
    <version>2.0.0</version>
</dependency>
<dependency>
    <groupId>ai.djl.pytorch</groupId>
    <artifactId>pytorch-jni</artifactId>
    <version>2.0.0-0.22.1</version>
    <scope>runtime</scope>
</dependency>
Linux GPU
<dependency>
    <groupId>ai.djl.pytorch</groupId>
    <artifactId>pytorch-native-cu118</artifactId>
    <classifier>linux-x86_64</classifier>
    <version>2.0.0</version>
    <scope>runtime</scope>
</dependency>
<dependency>
    <groupId>ai.djl.pytorch</groupId>
    <artifactId>pytorch-jni</artifactId>
    <version>2.0.0-0.22.1</version>
    <scope>runtime</scope>
</dependency>
macOS M1
<dependency>
    <groupId>ai.djl.pytorch</groupId>
    <artifactId>pytorch-native-cpu</artifactId>
    <classifier>osx-aarch64</classifier>
    <version>2.0.0</version>
    <scope>runtime</scope>
</dependency>
<dependency>
    <groupId>ai.djl.pytorch</groupId>
    <artifactId>pytorch-jni</artifactId>
    <version>2.0.0-0.22.1</version>
    <scope>runtime</scope>
</dependency>
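The platform dependency sets differ in only three things:

- the native artifact, pytorch-native-cpu or pytorch-native-cu118;
- the classifier;
- the pytorch-jni version, written as the PyTorch native version followed by the DJL version (2.0.0-0.22.1).

The helper below derives those values from the JVM's os.name and os.arch system properties. It is hypothetical, not a DJL API, and covers only the platforms listed.

```java
import java.util.Locale;

/**
 * Hypothetical helper, not a DJL API: derives the pytorch-native artifact, classifier
 * and pytorch-jni version used in the platform dependency lists.
 */
public final class NativeCoordinates {
    private NativeCoordinates() {}

    /** Maps os.name / os.arch values to a listed classifier, or throws if none fits. */
    public static String classifier(String osName, String osArch) {
        String os = osName.toLowerCase(Locale.ROOT);
        String arch = osArch.toLowerCase(Locale.ROOT);
        boolean x64 = arch.equals("amd64") || arch.equals("x86_64");
        boolean arm64 = arch.equals("aarch64") || arch.equals("arm64");
        if (os.startsWith("windows") && x64) {
            return "win-x86_64";
        }
        if (os.startsWith("linux") && x64) {
            return "linux-x86_64";
        }
        if (os.startsWith("mac") && arm64) {
            return "osx-aarch64";
        }
        throw new IllegalArgumentException("No dependency listed for " + osName + " / " + osArch);
    }

    /** GPU builds are pytorch-native-cu118 (Windows and Linux only); CPU builds are pytorch-native-cpu. */
    public static String nativeArtifact(String classifier, boolean gpu) {
        if (!gpu) {
            return "pytorch-native-cpu";
        }
        if (classifier.equals("osx-aarch64")) {
            throw new IllegalArgumentException("Only a CPU build is listed for macOS M1");
        }
        return "pytorch-native-cu118";
    }

    /** pytorch-jni version = PyTorch native version + "-" + DJL version, e.g. 2.0.0-0.22.1. */
    public static String jniVersion(String nativeVersion, String djlVersion) {
        return nativeVersion + "-" + djlVersion;
    }

    public static void main(String[] args) {
        try {
            String cls = classifier(System.getProperty("os.name"), System.getProperty("os.arch"));
            System.out.println("artifactId=" + nativeArtifact(cls, false) + " classifier=" + cls);
            System.out.println("pytorch-jni version=" + jniVersion("2.0.0", "0.22.1"));
        } catch (IllegalArgumentException e) {
            System.out.println(e.getMessage());
        }
    }
}
```

Only the listed combinations are mapped. Anything else fails loudly rather than silently picking an artifact that does not exist.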
Pre-trained models
The PyTorch model zoo contains Computer Vision (CV) models, grouped by task:
CV: Image Classification, Object Detection, Style Transfer, Image Generation

DJL Model Zoo
<dependency>
    <groupId>ai.djl.pytorch</groupId>
    <artifactId>pytorch-model-zoo</artifactId>
    <version>0.22.1</version>
</dependency>
pom
<project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
         xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
    <modelVersion>4.0.0</modelVersion>

    <groupId>org.example</groupId>
    <artifactId>w11</artifactId>
    <version>1.0-SNAPSHOT</version>
    <packaging>jar</packaging>

    <name>w11</name>
    <url>http://maven.apache.org</url>

    <properties>
        <project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
        <!-- Java language level; old maven-compiler-plugin defaults (1.5/1.6) are rejected by recent JDKs -->
        <maven.compiler.source>1.8</maven.compiler.source>
        <maven.compiler.target>1.8</maven.compiler.target>
    </properties>

    <dependencies>
        <!-- DJL API -->
        <dependency>
            <groupId>ai.djl</groupId>
            <artifactId>api</artifactId>
            <version>0.22.1</version>
        </dependency>
        <!-- PyTorch engine -->
        <dependency>
            <groupId>ai.djl.pytorch</groupId>
            <artifactId>pytorch-engine</artifactId>
            <version>0.22.1</version>
            <scope>runtime</scope>
        </dependency>
        <!-- Windows CPU native library; swap artifact/classifier for other platforms -->
        <dependency>
            <groupId>ai.djl.pytorch</groupId>
            <artifactId>pytorch-native-cpu</artifactId>
            <classifier>win-x86_64</classifier>
            <scope>runtime</scope>
            <version>2.0.0</version>
        </dependency>
        <!-- JNI bridge: version is <PyTorch native version>-<DJL version> -->
        <dependency>
            <groupId>ai.djl.pytorch</groupId>
            <artifactId>pytorch-jni</artifactId>
            <version>2.0.0-0.22.1</version>
            <scope>runtime</scope>
        </dependency>
        <!-- No-op SLF4J binding: suppresses log output -->
        <dependency>
            <groupId>org.slf4j</groupId>
            <artifactId>slf4j-nop</artifactId>
            <version>1.7.2</version>
            <type>jar</type>
        </dependency>
        <dependency>
            <groupId>junit</groupId>
            <artifactId>junit</artifactId>
            <version>3.8.1</version>
            <scope>test</scope>
        </dependency>
    </dependencies>
</project>
code
package org.example;

import ai.djl.inference.Predictor;
import ai.djl.modality.Classifications;
import ai.djl.modality.cv.Image;
import ai.djl.modality.cv.ImageFactory;
import ai.djl.modality.cv.transform.CenterCrop;
import ai.djl.modality.cv.transform.Normalize;
import ai.djl.modality.cv.transform.Resize;
import ai.djl.modality.cv.transform.ToTensor;
import ai.djl.modality.cv.translator.ImageClassificationTranslator;
import ai.djl.repository.zoo.Criteria;
import ai.djl.repository.zoo.ZooModel;
import ai.djl.training.util.DownloadUtils;
import ai.djl.training.util.ProgressBar;
import ai.djl.translate.Translator;

import java.nio.file.Paths;

public class App {
    public static void main(String[] args) throws Exception {
        // Traced ResNet-18 (the .gz is unpacked to resnet18.pt) and the ImageNet label file
        DownloadUtils.download(
                "https://djl-ai.s3.amazonaws.com/mlrepo/model/cv/image_classification/ai/djl/pytorch/resnet/0.0.1/traced_resnet18.pt.gz",
                "build/pytorch_models/resnet18/resnet18.pt", new ProgressBar());
        DownloadUtils.download(
                "https://djl-ai.s3.amazonaws.com/mlrepo/model/cv/image_classification/ai/djl/pytorch/synset.txt",
                "build/pytorch_models/resnet18/synset.txt", new ProgressBar());

        /*
         * Corresponding torchvision preprocessing:
         * preprocess = transforms.Compose([
         *     transforms.Resize(256),
         *     transforms.CenterCrop(224),
         *     transforms.ToTensor(),
         *     transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
         * ])
         */
        Translator<Image, Classifications> translator = ImageClassificationTranslator.builder()
                .addTransform(new Resize(256))
                .addTransform(new CenterCrop(224, 224))
                .addTransform(new ToTensor())
                .addTransform(new Normalize(
                        new float[] {0.485f, 0.456f, 0.406f},
                        new float[] {0.229f, 0.224f, 0.225f}))
                .optApplySoftmax(true)
                .build();

        Criteria<Image, Classifications> criteria = Criteria.builder()
                .setTypes(Image.class, Classifications.class)
                .optModelPath(Paths.get("build/pytorch_models/resnet18"))
                .optOption("mapLocation", "true") // this model requires mapLocation for GPU
                .optTranslator(translator)
                .optProgress(new ProgressBar())
                .build();

        // ZooModel and Predictor hold native resources, so close them with try-with-resources
        try (ZooModel<Image, Classifications> model = criteria.loadModel();
                Predictor<Image, Classifications> predictor = model.newPredictor()) {
            // Remote alternative: ImageFactory.getInstance().fromUrl("https://resources.djl.ai/images/0.png")
            // 0.png is a handwritten digit "0" (MNIST); ResNet-18 is an ImageNet classifier,
            // so expect ImageNet labels, not digits
            Image img = ImageFactory.getInstance()
                    .fromFile(Paths.get("build/pytorch_models/resnet18/0.png"));
            Classifications classifications = predictor.predict(img);
            System.out.println(classifications);
        }
    }
}
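The Translator above does the image math for you. To see what ToTensor, Normalize, a centered crop and optApplySoftmax(true) compute, here is a plain-Java sketch with no DJL dependency. DJL's own transforms may differ in details such as interpolation.

- ClassificationMath is a hypothetical class.
- The per-channel mean and std come from the code above.
- topK returns class indices only. It assumes line i of synset.txt is the label for class i.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Comparator;
import java.util.List;

/** Hypothetical, DJL-free illustration of the preprocessing and softmax arithmetic. */
public final class ClassificationMath {
    // Per-channel (R, G, B) mean and std, as passed to Normalize in App
    static final float[] MEAN = {0.485f, 0.456f, 0.406f};
    static final float[] STD = {0.229f, 0.224f, 0.225f};

    private ClassificationMath() {}

    /** ToTensor scales an 8-bit value to [0, 1]; Normalize then computes (x - mean[c]) / std[c]. */
    public static float normalize(int pixel, int channel) {
        float x = pixel / 255f;
        return (x - MEAN[channel]) / STD[channel];
    }

    /** Offset of a centered crop along one side, e.g. 224 out of 256 starts at 16. */
    public static int cropOffset(int size, int crop) {
        return (size - crop) / 2;
    }

    /** Numerically stable softmax: subtract the largest logit before exponentiating. */
    public static double[] softmax(double[] logits) {
        double max = Arrays.stream(logits).max().orElse(0.0);
        double[] out = new double[logits.length];
        double sum = 0.0;
        for (int i = 0; i < logits.length; i++) {
            out[i] = Math.exp(logits[i] - max);
            sum += out[i];
        }
        for (int i = 0; i < out.length; i++) {
            out[i] /= sum;
        }
        return out;
    }

    /** Indices of the k largest probabilities, highest first. */
    public static List<Integer> topK(double[] probs, int k) {
        List<Integer> idx = new ArrayList<>();
        for (int i = 0; i < probs.length; i++) {
            idx.add(i);
        }
        idx.sort(Comparator.comparingDouble((Integer i) -> probs[i]).reversed());
        return new ArrayList<>(idx.subList(0, Math.min(k, idx.size())));
    }

    public static void main(String[] args) {
        System.out.println("white pixel, red channel -> " + normalize(255, 0));
        System.out.println("crop offset 256 -> 224: " + cropOffset(256, 224));
        double[] probs = softmax(new double[] {1.0, 2.0, 3.0});
        System.out.println(Arrays.toString(probs) + " top-2: " + topK(probs, 2));
    }
}
```

Subtracting the maximum logit does not change the softmax result, but it keeps Math.exp from overflowing on large logits.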
Environment
PyTorch: 1.10.2
DJL: 0.22.1
OS: developed on Windows first, then deployed to and run on CentOS 7
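In this setup the model is exported with PyTorch 1.10.2 and run on the 2.0.0 native library of pytorch-engine 0.22.1. That relies on the earlier note that a newer engine can run older PyTorch models. If you script such a check, compare versions numerically, because as strings "1.10.2" sorts before "1.9.1". VersionCheck below is a hypothetical class. It assumes purely numeric dotted versions, so suffixes such as +cu118 are not handled.

```java
/** Hypothetical helper: numeric comparison of dotted versions such as "1.10.2" and "2.0.0". */
public final class VersionCheck {
    private VersionCheck() {}

    /** Negative, zero or positive as a is older than, equal to or newer than b; missing parts count as 0. */
    public static int compare(String a, String b) {
        String[] pa = a.split("\\.");
        String[] pb = b.split("\\.");
        int n = Math.max(pa.length, pb.length);
        for (int i = 0; i < n; i++) {
            int x = i < pa.length ? Integer.parseInt(pa[i]) : 0;
            int y = i < pb.length ? Integer.parseInt(pb[i]) : 0;
            if (x != y) {
                return Integer.compare(x, y);
            }
        }
        return 0;
    }

    /** True if the runtime native library is at least as new as the PyTorch that exported the model. */
    public static boolean nativeIsNewEnough(String exportVersion, String nativeVersion) {
        return compare(nativeVersion, exportVersion) >= 0;
    }

    public static void main(String[] args) {
        System.out.println(nativeIsNewEnough("1.10.2", "2.0.0")); // true
    }
}
```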
Creating the project on Linux
mvn archetype:generate -DgroupId=org.test -DartifactId=lnx1 -DarchetypeArtifactId=maven-archetype-quickstart -DinteractiveMode=false
Generating the PyTorch model in Python
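import torch
import torchvision

# resnet50() with no arguments has random weights; for ImageNet weights pass
# pretrained=True (torchvision < 0.13) or the weights= argument (torchvision >= 0.13)
model = torchvision.models.resnet50()
model.eval()  # inference mode: BatchNorm/Dropout behave as at test time

# Tracing records the ops executed for this example input (1 image, 3 channels, 224x224)
example = torch.rand(1, 3, 224, 224)
traced_script_module = torch.jit.trace(model, example)
traced_script_module.save("traced_model_resnet50.pt")

Note: do not use torch.save; use torch.jit.trace and save the resulting TorchScript module as above. torch.save pickles the Python object, which the libtorch runtime behind DJL cannot load.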