TensorFlow Serving Java是作为TensorFlow Serving的Java API,可以轻松地将基于TensorFlow模型的服务集成到Java应用程序中。从数据预处理到模型训练再到推理,TensorFlow Serving Java提供了完整的端到端解决方案。
一、简介
TensorFlow是深度学习领域最受欢迎的框架之一,但它并不仅仅是一个深度学习框架,同时也是一个支持大规模机器学习的通用计算框架。TensorFlow Serving是通过一个分布式运行的系统来服务于训练好的模型并提供可扩展性的接口,而TensorFlow Serving Java则是为Java应用程序提供直接使用这些模型的接口。
二、使用指南
TensorFlow Serving Java的使用非常方便,下面介绍一下主要的使用方法:
1. 依赖管理
使用TensorFlow Serving Java之前,需要在项目中添加以下依赖:
<dependency>
<groupId>org.tensorflow</groupId>
<artifactId>tensorflow-serving-api</artifactId>
<version>2.4.0</version>
</dependency>
2. 创建模型Client
使用TensorFlow Serving Java,需要首先创建一个服务器连接的Client对象,可以设置许多参数来配置模型服务的行为。以下是创建客户端的基本示例:
import org.tensorflow.framework.DataType;
import org.tensorflow.framework.TensorProto;
import org.tensorflow.servables.common.CleanupAble;
import java.util.Arrays;
import static org.tensorflow.framework.TensorProto.*;
public class TensorFlowService implements CleanupAble {
public TensorFlowService(String host, int port) {
// 创建一个指向指定端口的Stub
}
public float[][] predict(float[][] inputs) {
TensorProto.Builder builder = TensorProto.newBuilder();
builder.setDtype(DataType.DT_FLOAT);
builder.addAllFloatVal(Arrays.stream(inputs).flatMapToDouble(Arrays::stream).collect(Collectors.toList()));
builder.addDim(1);
builder.addDim(inputs.length);
builder.addDim(inputs[0].length);
TensorProto inputProto = builder.build();
// 发送请求
// 通过Stub返回结果
}
@Override
public void close() throws Exception {
// 销毁连接
}
}
3. 加载模型
在创建客户端之前,需要首先启动TensorFlow Serving服务器并加载要使用的模型。以下是在本地连接到TensorFlow Serving服务器并加载预处理好的Batch Normalization的CNN模型的示例代码:
import org.tensorflow.framework.ConfigProto;
import org.tensorflow.framework.GPUOptions;
import org.tensorflow.serving.*;
import tensorflow.serving.Model;
import tensorflow.serving.Model.VersionPolicy.LatestVersionPolicy;
import tensorflow.serving.SessionServiceGrpc.SessionServiceBlockingStub;
public TensorFlowService(ManagedChannel channel, String modelName, int modelVersion) {
// 创建stub用于与SessionService进行通信
SessionServiceBlockingStub stub = SessionServiceGrpc.newBlockingStub(channel);
// 创建请求
Model.ModelSpec modelSpec = Model.ModelSpec.newBuilder()
.setName(modelName)
.setVersionPolicy(Model.ModelVersionPolicy.newBuilder().setLatest(LatestVersionPolicy.newBuilder()).build())
.build();
GetModelMetadataRequest metadataRequest = GetModelMetadataRequest.newBuilder()
.setModelSpec(modelSpec)
.addMetadataField("signature_def")
.build();
// 获取模型元数据信息
GetModelMetadataResponse metadataResponse = stub.getModelMetadata(metadataRequest);
// 从元数据信息中获取模型输入和输出的名称、形状等
SignatureDefMap signatureDefMap = metadataResponse.getMetadataMap().get("signature_def");
SignatureDef signatureDef = signatureDefMap.getSignatureDefMap().entrySet().stream().findFirst().orElseThrow(RuntimeException::new)
.getValue();
Map.Entry input = signatureDef.getInputsMap().entrySet().stream().findFirst().orElseThrow(RuntimeException::new);
String inputTensorName = input.getKey();
TensorInfo inputTensorInfo = input.getValue();
Map.Entry output = signatureDef.getOutputsMap().entrySet().stream().findFirst().orElseThrow(RuntimeException::new);
String outputTensorName = output.getKey();
TensorInfo outputTensorInfo = output.getValue();
// 创建Session
SessionOptions sessionOptions = SessionOptions.newBuilder()
.setConfig(ConfigProto.newBuilder()
.setGpuOptions(GPUOptions.newBuilder()
.setPerProcessGpuMemoryFraction(0.5))
.build())
.build();
CreateSessionRequest sessionRequest = CreateSessionRequest.newBuilder()
.setSessionConfig(sessionOptions)
.setModelSpec(modelSpec)
.build();
CreateSessionResponse sessionResponse = stub.createSession(sessionRequest);
sessionHandle = sessionResponse.getSessionHandle();
}
4. 发送请求并处理结果
在创建客户端并且加载模型之后,可以开始向模型发送请求并处理模型返回的结果:
public float[][] predict(float[][] inputs) {
TensorProto.Builder builder = TensorProto.newBuilder();
builder.setDtype(DataType.DT_FLOAT);
builder.addAllFloatVal(Arrays.stream(inputs).flatMapToDouble(Arrays::stream).collect(Collectors.toList()));
builder.addDim(1);
builder.addDim(inputs.length);
builder.addDim(inputs[0].length);
TensorProto inputProto = builder.build();
// 创建request对象
RunOptions runOptions = RunOptions.newBuilder()
.setTraceLevel(RunOptions.TraceLevel.NO_TRACE)
.build();
Map inputsMap = Map.of(inputTensorName, inputProto);
Map outputsMap = Map.of(outputTensorName, new OutputTensorInfo(TensorShape.newBuilder().addDim(TensorShape.Dim.newBuilder().setSize(-1).build()).build(), DataType.DT_FLOAT));
RunRequest request = RunRequest.newBuilder()
.setSessionHandle(sessionHandle)
.setRunOptions(runOptions)
.setInputFeed(Maps.transformValues(inputsMap, value -> TensorProtoList.newBuilder().addTensor(value).build()))
.putAllFetch(fetch)
.build();
// 发送请求
RunResponse response = stub.run(request);
// 处理结果
}
三、总结
TensorFlow Serving Java为Java应用程序提供了与TensorFlow模型服务集成的无缝体验。它提供了简单易用的API,使得使用TensorFlow模型服务变得更容易。开发人员只需要简单地加载模型并使用客户端就可以处理输入和输出。TensorFlow Serving Java还提供了灵活的配置选项和可扩展性,并可以与其他TensorFlow项目无缝集成。
原创文章,作者:HIVDG,如若转载,请注明出处:https://www.506064.com/n/375631.html