网站首页 > 技术文章 正文
TensorFlow最近才通过livestream(由于COVID-19全球流行病)结束了一年一度的Dev Summit,并发布了许多令人兴奋的公告,其中大多数都集中在将机器学习推向新的高度。从一个强大的新发行版的核心TensorFlow平台(TF2.2)到新的Google Cloud AI平台,使TensorFlow在生产中的使用更加容易,更加广泛。
但这并不是这篇文章的重点。我们将深入探究今年的突破性公告之一:TensorFlow Lite Model Maker。
使用TF Lite Model Maker(它被放入TF Lite support库中)为移动和边缘设备构建模型非常容易。此外,Android Studio 4.1(目前是Canary版本),具有新的针对TF Lite模型的代码生成功能,可以自动生成TF Lite模型的Java包装类,从而简化了移动机器学习开发人员的模型开发和部署过程。
TensorFlow Lite是一个轻量级的跨平台解决方案,用于在移动和嵌入式设备上部署ML模型。如果您想了解今年TF Dev Summit上有关TensorFlow Lite的所有新闻和发布,我肯定建议您查看以下资源:
https://heartbeat.fritz.ai/tensorflow-dev-summit-2020-tensorflow-lite-19dde3153335
TensorFlow Lite Model Maker
TF Lite Model Maker是一个Python API,它使从头构建机器学习模型变得轻而易举。只需要5行代码(不包括imports),如下所示:
是的,我们将使用NSFW数据集。
在上面的演示中,我们加载了一个数据集并将其拆分为训练集和测试集。随后,我们训练,评估,并输出TF Lite model,以及标签(从子文件夹中取出)。
在幕后,Model Maker API使用迁移学习,用不同数据集和类别来重新训练模型。默认情况下,Model Maker API使用EfficientNet-Lite0作为基础模型。
EfficientNet-Lite是最近才发布的,属于能够在边缘设备上实现最新精度的图像分类模型家族。下图显示了EfficientNet-Lite模型的精度与大小比较,并将其与MobileNet和ResNet进行了比较。
来自文档,EfficientNet-Lite针对移动推理进行了优化
Model Maker API还允许我们切换底层模型。例如:
model = image_classifier.create(train_data,
model_spec=mobilenet_v2_spec, validation_data=validation_data)
或者,我们也可以从TensorFlow Hub传递模型,使用自定义的input shape,如下所示:
inception_v3_spec = ImageModelSpec(uri='tfhub_url_goes_here')
inception_v3_spec.input_image_shape = [299, 299]
//pass this spec into model_spec
我们还可以在Model Maker API的create函数中微调训练超参数,如epoch、dropout_rate和batch_size。
model = image_classifier.create(train_data, epochs=10)
现在我们已经很好地了解了modelmakerapi的核心功能,现在让我们收紧运行上述Python脚本所需的依赖关系。
升级TensorFlow
确保您运行的是Python3.6或更高版本,并且在你的macOS上安装了最新的pip版本。TensorFlow 2软件包要求pip版本>19.0。随后,pip安装以下更新TensorFlow:
pip install --user --upgrade tensorflow
注意:建议在virtual环境中安装Python包。使用virtual环境,您可以测试不同版本的库。要了解这个过程,可以查看本教程。
在我们的终端上,我们使用以下命令快速测试是否安装了最新的TensorFlow版本:
python3 -c 'import tensorflow as tf; print(tf.__version__)'
安装Model Maker库
在终端上执行如下命令安装Model Maker库:
pip3 install
git+https://github.com/tensorflow/examples.git#egg=tensorflow-examples[model_maker]
现在一切都准备好了,这意味着是时候训练我们的模型了。只需从macOS终端运行Python脚本。对于这个演示,我们使用了来自Kaggle的一个适当的NSFW数据集。一旦我们的模型准备好了,就可以导入到我们新的Android Studio项目中了。
- 首先,我们直接从import菜单导入tflite模型,并将其放在ml文件夹中。只需转到File > New > Other > TensorFlow Lite Model。
注意:在这样一个场景中,您不像上面那样导入模型,而是直接将其放在assets文件夹中,模型绑定将被禁用,Android Studio的自动代码生成不会为分类器创建Java包装类,除非您将其移到ml文件夹中。
- 其次,Android Studio现在有一个模型查看器,可以显示元数据摘要-输入和输出张量、这些张量的描述以及示例代码,如下所示:
默认情况下,ModelMaker API只生成由输入和输出形状组成的最小元数据。为了扩展和添加更多的上下文,如作者、版本、许可证以及输入和输出描述,我们可以利用新的扩展元数据特性(目前处于试验阶段)。
打开ML Model binding
尽管将tflite模型放在ml目录中,但model binding也不会自动启用。您需要在您的app的build.gradle 中添加buildFeatures和aaptOptions来打开:
android{
buildFeatures {
mlModelBinding true
}
aaptOptions {
noCompress "tflite"
}
}
我们的模型分类器现在可供我们运行推断。是时候build.gradle文件中添加tensorflow-lite依赖了。
dependencies {
implementation 'org.tensorflow:tensorflow-lite:0.0.0-nightly'
implementation 'org.tensorflow:tensorflow-lite-gpu:0.0.0-nightly'
implementation 'org.tensorflow:tensorflow-lite-support:0.0.0-nightly'
implementation 'org.apache.commons:commons-compress:1.19'
}
设置Activity Layout
现在是时候在我们的Activity中放置UI元素了。简单一点,我们的 activity_main.xml文件由一个RecyclerView和一个Button组成:
<?xml version="1.0" encoding="utf-8"?>
<RelativeLayout xmlns:android="http://schemas.android.com/apk/res/android"
xmlns:app="http://schemas.android.com/apk/res-auto"
xmlns:tools="http://schemas.android.com/tools"
android:layout_width="match_parent"
android:layout_height="match_parent"
android:orientation="vertical"
tools:context=".MainActivity">
<androidx.recyclerview.widget.RecyclerView
android:id="@+id/recyclerView"
android:layout_width="match_parent"
android:layout_above="@+id/btnClassifier"
android:layout_height="match_parent" />
<Button
android:id="@+id/btnClassifier"
android:text="Run Classifier"
android:layout_alignParentBottom="true"
android:layout_centerHorizontal="true"
android:layout_width="wrap_content"
android:layout_height="wrap_content"/>
</RelativeLayout>
为了填充RecyclerView的adapter,我们需要一个模型。下面的Kotlin数据类包含一个图像、预测文本,以及一个布尔标志,用于指示输入图像是否为NSFW。
data class DataModel(var drawableID: Int,
var isNSFW: Boolean,
var prediction: String)
以下XML代码表示RecyclerView的每一行的布局:
<?xml version="1.0" encoding="utf-8"?>
<LinearLayout xmlns:android="http://schemas.android.com/apk/res/android"
android:layout_width="match_parent"
android:layout_height="wrap_content"
android:orientation="horizontal"
android:padding="10dp">
<ImageView
android:id="@+id/imageView"
android:layout_width="100dp"
android:layout_height="100dp"
android:scaleType="centerCrop" />
<TextView
android:id="@+id/tvPrediction"
android:layout_width="match_parent"
android:layout_height="wrap_content"
android:layout_gravity="center"
android:gravity="end|center_vertical"
android:textSize="20sp" />
</LinearLayout>
现在我们已经创建了数据模型和视图,是时候将它们提供给RecyclerView的适配器了。
设置RecyclerView的Adapter
以下代码创建RecyclerView的适配器类:
import android.content.Context
import android.graphics.Color
import android.view.LayoutInflater
import android.view.View
import android.view.ViewGroup
import androidx.core.content.ContextCompat
import androidx.recyclerview.widget.RecyclerView
import kotlinx.android.synthetic.main.item_row.view.*
class RecyclerViewAdapter(val items: ArrayList<DataModel>, val context: Context) :
RecyclerView.Adapter<ViewHolder>() {
override fun getItemCount(): Int {
return items.size
}
override fun onCreateViewHolder(parent: ViewGroup, viewType: Int): ViewHolder {
return ViewHolder(
LayoutInflater.from(context).inflate(
R.layout.item_row,
parent,
false
)
)
}
override fun onBindViewHolder(holder: ViewHolder, position: Int) {
val data = items.get(position)
holder.tvPrediction?.text = data.prediction
val image = ContextCompat.getDrawable(context, data.drawableID)
holder.imageView.setImageDrawable(image)
if (data.isNSFW) {
holder.imageView.setColorFilter(Color.BLACK)
} else {
holder.imageView.setColorFilter(Color.TRANSPARENT)
}
}
}
class ViewHolder(view: View) : RecyclerView.ViewHolder(view) {
val tvPrediction = view.tvPrediction
val imageView = view.imageView
}
我们根据NSFW的输出在ImageView上设置了一个颜色过滤器(NSFW的图像由于明显的原因隐藏在黑色中)。
最后,来到了我们的MainActivity.kt,在这里我们初始化上面的适配器,更重要的是,对图像列表运行推断。
运行TF Lite图像分类器
为了运行模型,我们需要预处理输入以满足模型的要求。TensorFlow lite有大把的内置预处理方法。要使用它们,我们首先需要初始化ImageProcessor,然后添加所需的运算符:
预处理输入图像
在下面的代码中,我们将输入图像的大小调整为224×224,即模型要求的输入形状的尺寸:
val imageProcessor = ImageProcessor.Builder()
.add(ResizeOp(224, 224, ResizeOp.ResizeMethod.BILINEAR))
.build()
var tImage = TensorImage(DataType.FLOAT32)
tImage.load(bitmap)
tImage = imageProcessor.process(tImage)
TensorImage是喂给我们的TensorFlow Lite模型的输入。但是,在我们运行推断之前,让我们创建一个后处理器来规范化输出概率。
设置后处理器
后处理器基本上是一个容器,它将对结果进行反量化:
val probabilityProcessor =
TensorProcessor.Builder().add(NormalizeOp(0f, 255f)).build()
运行推断
下面几行代码实例化从模型自动生成的分类器,传递输入张量图像,并在outputBuffer中获得结果:
val model = NsfwClassifier.newInstance(this@MainActivity)
val outputs =
model.process(probabilityProcessor.process(tImage.tensorBuffer))
val outputBuffer = outputs.outputFeature0AsTensorBuffer
val tensorLabel = TensorLabel(labelsList, outputBuffer)
TensorLabel用于将相关概率映射到标签。在我们的模型中,只有两个标签:”NSFW”和”SFW”。我们已经把它们设置给labelsList。在不同的场景中,您可以解析labels.txt文件来获取所有类别的信息,比如这里。
最后,使用mapWithFloatValue函数,我们可以获得NSFW和SFW类别的概率。
MainActivity.kt的完整代码如下所示。它在每个图像上运行上述图像分类器,并使用相应的数据更改更新RecyclerView适配器:
class MainActivity : AppCompatActivity() {
val dataArray: ArrayList<DataModel> = ArrayList()
val labelsList = arrayListOf("NSFW", "SFW")
override fun onCreate(savedInstanceState: Bundle?) {
super.onCreate(savedInstanceState)
setContentView(R.layout.activity_main)
btnClassifier.setOnClickListener {
val iterate = dataArray.listIterator()
while (iterate.hasNext()) {
val oldValue = iterate.next()
runImageClassifier(oldValue)
}
recyclerView.adapter?.notifyDataSetChanged()
}
populateData()
recyclerView.layoutManager = LinearLayoutManager(this)
recyclerView.adapter = RecyclerViewAdapter(dataArray, this)
}
fun populateData()
{
dataArray.add(DataModel(R.drawable.sfw_1,true,""))
dataArray.add(DataModel(R.drawable.nsfw,true,""))
dataArray.add(DataModel(R.drawable.nsfw2,true,""))
dataArray.add(DataModel(R.drawable.sfw,true,""))
}
fun runImageClassifier(data: DataModel)
{
val bitmap =
BitmapFactory.decodeResource(applicationContext.resources, data.drawableID)
try {
val probabilityProcessor =
TensorProcessor.Builder().add(NormalizeOp(0f, 255f)).build()
val imageProcessor = ImageProcessor.Builder()
.add(ResizeOp(224, 224, ResizeOp.ResizeMethod.BILINEAR))
.build()
var tImage = TensorImage(DataType.FLOAT32)
tImage.load(bitmap)
tImage = imageProcessor.process(tImage)
val model = NsfwClassifier.newInstance(this@MainActivity)
val outputs =
model.process(probabilityProcessor.process(tImage.tensorBuffer))
val outputBuffer = outputs.outputFeature0AsTensorBuffer
val tensorLabel = TensorLabel(labelsList, outputBuffer)
val nsfwProbability = tensorLabel.mapWithFloatValue.get("NSFW")
if (nsfwProbability?.compareTo(0.5)!! < 0){
data.isNSFW = false
}
data.prediction = "NSFW : "+ tensorLabel.mapWithFloatValue.get("NSFW")
} catch (e: Exception) {
Log.d("TAG", "Exception is " + e.localizedMessage)
}
}
}
下面是上述应用程序的实际输出:
只有SFW图像被显示
结束语
可以说,Model Maker Python库将继续存在,并将被希望在设备上快速部署ML模型的移动开发人员广泛使用。
对于那些了解苹果机器学习技术的人来说,TF Lite Model Maker与Create ML类似,至少在理论上是这样。目前,Model Maker API只支持图像和文本分类用例,目标检测和QR阅读器预计很快就会推出。
Android Studio对ML模型绑定和自动代码生成的支持消除了与ByteBuffer交互的需要,正如我们在之前的TensorFlow Lite Android教程中所做的那样。
扩展元数据(在编写时处于实验阶段)还允许我们生成定制的、特定于平台的包装器代码,从而进一步减少了我们需要编写的例行代码的数量。在将来的教程中,我们将研究自定义代码生成以及更多内容。
以上教程的完整源代码可以在这个GitHub仓库 https://github.com/anupamchugh/AndroidTFLiteModelMaker (或者码云克隆 https://gitee.com/freehug/anupamchugh-android-tflite-model-maker )中找到。
感谢阅读。
原文链接
猜你喜欢
- 2024-12-28 游戏画面绘图 透明特效的制作方法
- 2024-12-28 Lazarus 打印 raz打印方法
- 2024-12-28 Android 性能优化工具篇:如何使用 DDMS 中的 TraceView 工具
- 2024-12-28 「3D效果图」法线贴图的正确使用方法和技巧
- 2024-12-28 用户界面控件Xtreme Calendar发布v17.0.0
- 2024-12-28 UG各版本安装时出现报警问题及解决方法
- 2024-12-28 记本人使用人工智能辅助编程的实践
- 2024-12-28 6.1 用Bitmap实现精确去重 bitmap字符串去重
- 2024-12-28 MFC常用函数与指令 mfcformat函数
- 2024-12-28 MFC中双缓冲技术 双缓冲技术java
- 02-21走进git时代, 你该怎么玩?_gits
- 02-21GitHub是什么?它可不仅仅是云中的Git版本控制器
- 02-21Git常用操作总结_git基本用法
- 02-21为什么互联网巨头使用Git而放弃SVN?(含核心命令与原理)
- 02-21Git 高级用法,喜欢就拿去用_git基本用法
- 02-21Git常用命令和Git团队使用规范指南
- 02-21总结几个常用的Git命令的使用方法
- 02-21Git工作原理和常用指令_git原理详解
- 最近发表
- 标签列表
-
- cmd/c (57)
- c++中::是什么意思 (57)
- sqlset (59)
- ps可以打开pdf格式吗 (58)
- phprequire_once (61)
- localstorage.removeitem (74)
- routermode (59)
- vector线程安全吗 (70)
- & (66)
- java (73)
- org.redisson (64)
- log.warn (60)
- cannotinstantiatethetype (62)
- js数组插入 (83)
- resttemplateokhttp (59)
- gormwherein (64)
- linux删除一个文件夹 (65)
- mac安装java (72)
- reader.onload (61)
- outofmemoryerror是什么意思 (64)
- flask文件上传 (63)
- eacces (67)
- 查看mysql是否启动 (70)
- java是值传递还是引用传递 (58)
- 无效的列索引 (74)