Skip to content

Commit 3a432cd

Browse files
authored
Use ktfmt to format code (#11061)
Use ktfmt. The format is a bit different from android studio. Ideally we do some check in CI
1 parent ed718a8 commit 3a432cd

File tree

7 files changed

+524
-491
lines changed

7 files changed

+524
-491
lines changed

extension/android/executorch_android/src/androidTest/java/org/pytorch/executorch/LlmModuleInstrumentationTest.kt

Lines changed: 22 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,9 @@ import android.Manifest
1111
import androidx.test.InstrumentationRegistry
1212
import androidx.test.ext.junit.runners.AndroidJUnit4
1313
import androidx.test.rule.GrantPermissionRule
14+
import java.io.File
15+
import java.io.IOException
16+
import java.net.URISyntaxException
1417
import org.apache.commons.io.FileUtils
1518
import org.json.JSONException
1619
import org.json.JSONObject
@@ -21,11 +24,8 @@ import org.junit.Test
2124
import org.junit.runner.RunWith
2225
import org.pytorch.executorch.extension.llm.LlmCallback
2326
import org.pytorch.executorch.extension.llm.LlmModule
24-
import java.io.File
25-
import java.io.IOException
26-
import java.net.URISyntaxException
2727

28-
/** Unit tests for [org.pytorch.executorch.extension.llm.LlmModule]. */
28+
/** Unit tests for [org.pytorch.executorch.extension.llm.LlmModule]. */
2929
@RunWith(AndroidJUnit4::class)
3030
class LlmModuleInstrumentationTest : LlmCallback {
3131
private val results: MutableList<String> = ArrayList()
@@ -69,16 +69,20 @@ class LlmModuleInstrumentationTest : LlmCallback {
6969
@Test
7070
@Throws(IOException::class, URISyntaxException::class)
7171
fun testGenerateAndStop() {
72-
llmModule!!.generate(TEST_PROMPT, SEQ_LEN, object : LlmCallback {
73-
override fun onResult(result: String) {
74-
this@LlmModuleInstrumentationTest.onResult(result)
75-
llmModule!!.stop()
76-
}
72+
llmModule!!.generate(
73+
TEST_PROMPT,
74+
SEQ_LEN,
75+
object : LlmCallback {
76+
override fun onResult(result: String) {
77+
this@LlmModuleInstrumentationTest.onResult(result)
78+
llmModule!!.stop()
79+
}
7780

78-
override fun onStats(stats: String) {
79-
this@LlmModuleInstrumentationTest.onStats(stats)
80-
}
81-
})
81+
override fun onStats(stats: String) {
82+
this@LlmModuleInstrumentationTest.onStats(stats)
83+
}
84+
},
85+
)
8286

8387
val stoppedResultSize = results.size
8488
Assert.assertTrue(stoppedResultSize < SEQ_LEN)
@@ -97,8 +101,7 @@ class LlmModuleInstrumentationTest : LlmCallback {
97101
val promptEvalEndMs = jsonObject.getInt("prompt_eval_end_ms")
98102
tps = numGeneratedTokens.toFloat() / (inferenceEndMs - promptEvalEndMs) * 1000
99103
tokensPerSecond.add(tps)
100-
} catch (_: JSONException) {
101-
}
104+
} catch (_: JSONException) {}
102105
}
103106

104107
companion object {
@@ -109,7 +112,10 @@ class LlmModuleInstrumentationTest : LlmCallback {
109112
private const val SEQ_LEN = 32
110113

111114
private fun getTestFilePath(fileName: String): String {
112-
return InstrumentationRegistry.getInstrumentation().targetContext.externalCacheDir.toString() + fileName
115+
return InstrumentationRegistry.getInstrumentation()
116+
.targetContext
117+
.externalCacheDir
118+
.toString() + fileName
113119
}
114120
}
115121
}

extension/android/executorch_android/src/androidTest/java/org/pytorch/executorch/ModuleE2ETest.kt

Lines changed: 13 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -13,17 +13,17 @@ import android.graphics.BitmapFactory
1313
import androidx.test.InstrumentationRegistry
1414
import androidx.test.ext.junit.runners.AndroidJUnit4
1515
import androidx.test.rule.GrantPermissionRule
16+
import java.io.File
17+
import java.io.IOException
18+
import java.net.URISyntaxException
1619
import org.apache.commons.io.FileUtils
1720
import org.junit.Assert
1821
import org.junit.Rule
1922
import org.junit.Test
2023
import org.junit.runner.RunWith
2124
import org.pytorch.executorch.TensorImageUtils.bitmapToFloat32Tensor
22-
import java.io.File
23-
import java.io.IOException
24-
import java.net.URISyntaxException
2525

26-
/** Unit tests for [Module]. */
26+
/** Unit tests for [Module]. */
2727
@RunWith(AndroidJUnit4::class)
2828
class ModuleE2ETest {
2929
@get:Rule
@@ -46,7 +46,7 @@ class ModuleE2ETest {
4646
bitmapToFloat32Tensor(
4747
bitmap,
4848
TensorImageUtils.TORCHVISION_NORM_MEAN_RGB,
49-
TensorImageUtils.TORCHVISION_NORM_STD_RGB
49+
TensorImageUtils.TORCHVISION_NORM_STD_RGB,
5050
)
5151

5252
val module = Module.load(getTestFilePath(filePath))
@@ -69,7 +69,10 @@ class ModuleE2ETest {
6969

7070
val module = Module.load(getTestFilePath("/mv3_xnnpack_fp32.pte"))
7171
val expectedBackends = arrayOf("XnnpackBackend")
72-
Assert.assertArrayEquals(expectedBackends, module.getMethodMetadata("forward").getBackends())
72+
Assert.assertArrayEquals(
73+
expectedBackends,
74+
module.getMethodMetadata("forward").getBackends(),
75+
)
7376
}
7477

7578
@Test
@@ -92,7 +95,10 @@ class ModuleE2ETest {
9295

9396
companion object {
9497
private fun getTestFilePath(fileName: String): String {
95-
return InstrumentationRegistry.getInstrumentation().targetContext.externalCacheDir.toString() + fileName
98+
return InstrumentationRegistry.getInstrumentation()
99+
.targetContext
100+
.externalCacheDir
101+
.toString() + fileName
96102
}
97103

98104
fun argmax(array: FloatArray): Int {

extension/android/executorch_android/src/androidTest/java/org/pytorch/executorch/ModuleInstrumentationTest.kt

Lines changed: 12 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -11,20 +11,20 @@ import android.Manifest
1111
import androidx.test.InstrumentationRegistry
1212
import androidx.test.ext.junit.runners.AndroidJUnit4
1313
import androidx.test.rule.GrantPermissionRule
14-
import org.apache.commons.io.FileUtils
15-
import org.junit.Assert
16-
import org.junit.Before
17-
import org.junit.Rule
18-
import org.junit.Test
19-
import org.junit.runner.RunWith
2014
import java.io.File
2115
import java.io.IOException
2216
import java.net.URISyntaxException
2317
import java.util.concurrent.CountDownLatch
2418
import java.util.concurrent.TimeUnit
2519
import java.util.concurrent.atomic.AtomicInteger
20+
import org.apache.commons.io.FileUtils
21+
import org.junit.Assert
22+
import org.junit.Before
23+
import org.junit.Rule
24+
import org.junit.Test
25+
import org.junit.runner.RunWith
2626

27-
/** Unit tests for [Module]. */
27+
/** Unit tests for [Module]. */
2828
@RunWith(AndroidJUnit4::class)
2929
class ModuleInstrumentationTest {
3030
@Before
@@ -150,8 +150,7 @@ class ModuleInstrumentationTest {
150150
val results = module.forward()
151151
Assert.assertTrue(results[0].isTensor)
152152
completed.incrementAndGet()
153-
} catch (_: InterruptedException) {
154-
}
153+
} catch (_: InterruptedException) {}
155154
}
156155

157156
val threads = arrayOfNulls<Thread>(numThreads)
@@ -179,7 +178,10 @@ class ModuleInstrumentationTest {
179178
private const val ACCESS_FAILED = 0x22
180179

181180
private fun getTestFilePath(fileName: String): String {
182-
return InstrumentationRegistry.getInstrumentation().targetContext.externalCacheDir.toString() + fileName
181+
return InstrumentationRegistry.getInstrumentation()
182+
.targetContext
183+
.externalCacheDir
184+
.toString() + fileName
183185
}
184186
}
185187
}

extension/android/executorch_android/src/androidTest/java/org/pytorch/executorch/TensorImageUtils.kt

Lines changed: 42 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -16,34 +16,41 @@ import java.nio.FloatBuffer
1616
* [android.media.Image] source.
1717
*/
1818
object TensorImageUtils {
19-
@JvmField
20-
var TORCHVISION_NORM_MEAN_RGB: FloatArray = floatArrayOf(0.485f, 0.456f, 0.406f)
21-
@JvmField
22-
var TORCHVISION_NORM_STD_RGB: FloatArray = floatArrayOf(0.229f, 0.224f, 0.225f)
19+
@JvmField var TORCHVISION_NORM_MEAN_RGB: FloatArray = floatArrayOf(0.485f, 0.456f, 0.406f)
20+
21+
@JvmField var TORCHVISION_NORM_STD_RGB: FloatArray = floatArrayOf(0.229f, 0.224f, 0.225f)
2322

2423
/**
25-
* Creates new [Tensor] from full [android.graphics.Bitmap], normalized with specified
26-
* in parameters mean and std.
24+
* Creates new [Tensor] from full [android.graphics.Bitmap], normalized with specified in
25+
* parameters mean and std.
2726
*
2827
* @param normMeanRGB means for RGB channels normalization, length must equal 3, RGB order
2928
* @param normStdRGB standard deviation for RGB channels normalization, length must equal 3, RGB
30-
* order
29+
* order
3130
*/
3231
@JvmStatic
3332
fun bitmapToFloat32Tensor(
34-
bitmap: Bitmap, normMeanRGB: FloatArray, normStdRGB: FloatArray
33+
bitmap: Bitmap,
34+
normMeanRGB: FloatArray,
35+
normStdRGB: FloatArray,
3536
): Tensor {
3637
checkNormMeanArg(normMeanRGB)
3738
checkNormStdArg(normStdRGB)
3839

3940
return bitmapToFloat32Tensor(
40-
bitmap, 0, 0, bitmap.width, bitmap.height, normMeanRGB, normStdRGB
41+
bitmap,
42+
0,
43+
0,
44+
bitmap.width,
45+
bitmap.height,
46+
normMeanRGB,
47+
normStdRGB,
4148
)
4249
}
4350

4451
/**
45-
* Writes tensor content from specified [android.graphics.Bitmap], normalized with specified
46-
* in parameters mean and std to specified [java.nio.FloatBuffer] with specified offset.
52+
* Writes tensor content from specified [android.graphics.Bitmap], normalized with specified in
53+
* parameters mean and std to specified [java.nio.FloatBuffer] with specified offset.
4754
*
4855
* @param bitmap [android.graphics.Bitmap] as a source for Tensor data
4956
* @param x - x coordinate of top left corner of bitmap's area
@@ -52,7 +59,7 @@ object TensorImageUtils {
5259
* @param height - height of bitmap's area
5360
* @param normMeanRGB means for RGB channels normalization, length must equal 3, RGB order
5461
* @param normStdRGB standard deviation for RGB channels normalization, length must equal 3, RGB
55-
* order
62+
* order
5663
*/
5764
fun bitmapToFloatBuffer(
5865
bitmap: Bitmap,
@@ -63,7 +70,7 @@ object TensorImageUtils {
6370
normMeanRGB: FloatArray,
6471
normStdRGB: FloatArray,
6572
outBuffer: FloatBuffer,
66-
outBufferOffset: Int
73+
outBufferOffset: Int,
6774
) {
6875
checkOutBufferCapacity(outBuffer, outBufferOffset, width, height)
6976
checkNormMeanArg(normMeanRGB)
@@ -88,8 +95,8 @@ object TensorImageUtils {
8895
}
8996

9097
/**
91-
* Creates new [Tensor] from specified area of [android.graphics.Bitmap], normalized
92-
* with specified in parameters mean and std.
98+
* Creates new [Tensor] from specified area of [android.graphics.Bitmap], normalized with
99+
* specified in parameters mean and std.
93100
*
94101
* @param bitmap [android.graphics.Bitmap] as a source for Tensor data
95102
* @param x - x coordinate of top left corner of bitmap's area
@@ -98,7 +105,7 @@ object TensorImageUtils {
98105
* @param height - height of bitmap's area
99106
* @param normMeanRGB means for RGB channels normalization, length must equal 3, RGB order
100107
* @param normStdRGB standard deviation for RGB channels normalization, length must equal 3, RGB
101-
* order
108+
* order
102109
*/
103110
fun bitmapToFloat32Tensor(
104111
bitmap: Bitmap,
@@ -107,7 +114,7 @@ object TensorImageUtils {
107114
width: Int,
108115
height: Int,
109116
normMeanRGB: FloatArray,
110-
normStdRGB: FloatArray
117+
normStdRGB: FloatArray,
111118
): Tensor {
112119
checkNormMeanArg(normMeanRGB)
113120
checkNormStdArg(normStdRGB)
@@ -118,17 +125,31 @@ object TensorImageUtils {
118125
}
119126

120127
private fun checkOutBufferCapacity(
121-
outBuffer: FloatBuffer, outBufferOffset: Int, tensorWidth: Int, tensorHeight: Int
128+
outBuffer: FloatBuffer,
129+
outBufferOffset: Int,
130+
tensorWidth: Int,
131+
tensorHeight: Int,
122132
) {
123-
check(outBufferOffset + 3 * tensorWidth * tensorHeight <= outBuffer.capacity()) { "Buffer underflow" }
133+
check(outBufferOffset + 3 * tensorWidth * tensorHeight <= outBuffer.capacity()) {
134+
"Buffer underflow"
135+
}
124136
}
125137

126138
private fun checkTensorSize(tensorWidth: Int, tensorHeight: Int) {
127-
require(!(tensorHeight <= 0 || tensorWidth <= 0)) { "tensorHeight and tensorWidth must be positive" }
139+
require(!(tensorHeight <= 0 || tensorWidth <= 0)) {
140+
"tensorHeight and tensorWidth must be positive"
141+
}
128142
}
129143

130144
private fun checkRotateCWDegrees(rotateCWDegrees: Int) {
131-
require(!(rotateCWDegrees != 0 && rotateCWDegrees != 90 && rotateCWDegrees != 180 && rotateCWDegrees != 270)) { "rotateCWDegrees must be one of 0, 90, 180, 270" }
145+
require(
146+
!(rotateCWDegrees != 0 &&
147+
rotateCWDegrees != 90 &&
148+
rotateCWDegrees != 180 &&
149+
rotateCWDegrees != 270)
150+
) {
151+
"rotateCWDegrees must be one of 0, 90, 180, 270"
152+
}
132153
}
133154

134155
private fun checkNormStdArg(normStdRGB: FloatArray) {

0 commit comments

Comments
 (0)