diff options
| author | Joel Klinghed <the_jk@spawned.biz> | 2024-08-23 00:24:18 +0200 |
|---|---|---|
| committer | Joel Klinghed <the_jk@spawned.biz> | 2024-08-23 00:24:18 +0200 |
| commit | 16d2b1750a78527a5524bfc3171a42f67c323508 (patch) | |
| tree | 1defa7bb7f44a089103582b0777662394ef1ba56 | |
| parent | a95264b5273748330c3126632277fd7a0db8ec91 (diff) | |
samba: Add support for read/write file
6 files changed, 355 insertions, 3 deletions
diff --git a/libs/samba/src/main/cpp/jni.hpp b/libs/samba/src/main/cpp/jni.hpp index 3b5078a..1729828 100644 --- a/libs/samba/src/main/cpp/jni.hpp +++ b/libs/samba/src/main/cpp/jni.hpp @@ -51,7 +51,7 @@ class Ref { T ptr_; }; -constexpr jint JNI_VERSION = JNI_VERSION_1_2; +constexpr jint JNI_VERSION = JNI_VERSION_1_4; JNIEnv* AttachCurrentThread(); diff --git a/libs/samba/src/main/cpp/samba.cpp b/libs/samba/src/main/cpp/samba.cpp index 1e6b425..4a8af51 100644 --- a/libs/samba/src/main/cpp/samba.cpp +++ b/libs/samba/src/main/cpp/samba.cpp @@ -1,7 +1,9 @@ #include <algorithm> #include <cassert> +#include <fcntl.h> #include <jni.h> #include <memory> +#include <limits> #include <optional> #include <string> #include <string_view> @@ -46,6 +48,74 @@ class Dir { smb2dir* const dir_; }; +class File { + public: + File(std::shared_ptr<smb2_context> context, smb2fh* fh) : context_(std::move(context)), fh_(fh) { + assert(context_ && fh_); + } + + ~File() { + smb2_close(context_.get(), fh_); + } + + File(const File&) = delete; + File& operator=(const File&) = delete; + + int32_t read(uint8_t* data, int32_t size) { + if (size <= 0) return 0; + int32_t total = 0; + while (size > std::numeric_limits<int>::max()) { + int ret = smb2_read(context_.get(), fh_, data, + std::numeric_limits<int>::max()); + if (ret < 0) return total ? total : ret; + if (ret != std::numeric_limits<int>::max()) + return total + ret; + data += ret; + size -= ret; + } + int ret = smb2_read(context_.get(), fh_, data, size); + if (ret < 0) return total ? total : ret; + return total + ret; + } + + int32_t write(const uint8_t* data, int32_t size) { + if (size <= 0) return 0; + int32_t total = 0; + while (size > std::numeric_limits<int>::max()) { + int ret = smb2_write(context_.get(), fh_, data, + std::numeric_limits<int>::max()); + if (ret < 0) return total ? total : ret; + if (ret != std::numeric_limits<int>::max()) + return total + ret; + data += ret; + size -= ret; + } + int ret = smb2_write(context_.get(), fh_, data, size); + if (ret < 0) return total ? total : ret; + return total + ret; + } + + int64_t seek(int64_t offset, int32_t native_whence) { + int whence; + switch (native_whence) { + case 0: + whence = SEEK_SET; + break; + case 1: + whence = SEEK_CUR; + break; + default: + assert(false); + return -1; + } + return smb2_lseek(context_.get(), fh_, offset, whence, nullptr); + } + + private: + std::shared_ptr<smb2_context> context_; + smb2fh* const fh_; +}; + class Url { public: explicit Url(smb2_url* url) : url_(url) { @@ -142,6 +212,23 @@ class Context { } } + [[nodiscard]] std::unique_ptr<File> OpenFile(const std::string& path, int32_t mode) { + int flags; + switch (mode) { + case 0: + flags = O_RDONLY; + break; + case 1: + flags = O_WRONLY | O_CREAT | O_TRUNC; + break; + default: + assert(false); + return nullptr; + } + auto* ptr = smb2_open(context_.get(), path.c_str(), flags); + return ptr ? std::make_unique<File>(context_, ptr) : nullptr; + } + private: struct ContextDeleter { void operator()(smb2_context* context) { @@ -205,6 +292,10 @@ jstring nativeContextReadLink(JNIEnv* env, jclass clazz, jlong ptr, jstring path return nullptr; } +jlong nativeContextOpenFile(JNIEnv* env, jclass clazz, jlong ptr, jstring path, jint mode) { + return reinterpret_cast<jlong>(reinterpret_cast<Context*>(ptr)->OpenFile(jni::StringToUTF8(env, jni::ParamRef<jstring>(env, path)), mode).release()); +} + void nativeUrlDestroy(JNIEnv* env, jclass clazz, jlong ptr) { delete reinterpret_cast<Url*>(ptr); } @@ -224,6 +315,50 @@ jobjectArray nativeDirList(JNIEnv* env, jclass clazz, jlong ptr) { return reinterpret_cast<Dir*>(ptr)->List(env, g_DirEntryClass).release(); } +void nativeFileDestroy(JNIEnv* env, jclass clazz, jlong ptr) { + delete reinterpret_cast<File*>(ptr); +} + +jint nativeFileRead(JNIEnv* env, jclass clazz, jlong ptr, jbyteArray array, jint offset, jint length) { + jboolean is_copy = JNI_FALSE; + bool critical = true; + auto* data = reinterpret_cast<jbyte*>(env->GetPrimitiveArrayCritical(array, &is_copy)); + if (!data) { + critical = false; + data = env->GetByteArrayElements(array, &is_copy); + if (!data) return -1; + } + auto ret = reinterpret_cast<File *>(ptr)->read(reinterpret_cast<uint8_t *>(data + offset), length); + if (critical) { + env->ReleasePrimitiveArrayCritical(array, data, JNI_COMMIT); + } else { + env->ReleaseByteArrayElements(array, data, JNI_COMMIT); + } + return ret; +} + +jint nativeFileWrite(JNIEnv* env, jclass clazz, jlong ptr, jbyteArray array, jint offset, jint length) { + jboolean is_copy = JNI_FALSE; + bool critical = true; + auto* data = reinterpret_cast<jbyte*>(env->GetPrimitiveArrayCritical(array, &is_copy)); + if (!data) { + critical = false; + data = env->GetByteArrayElements(array, &is_copy); + if (!data) return -1; + } + auto ret = reinterpret_cast<File *>(ptr)->write(reinterpret_cast<uint8_t *>(data + offset), length); + if (critical) { + env->ReleasePrimitiveArrayCritical(array, data, JNI_ABORT); + } else { + env->ReleaseByteArrayElements(array, data, JNI_ABORT); + } + return ret; +} + +jlong nativeFileSeek(JNIEnv* env, jclass clazz, jlong ptr, jlong offset, jint whence) { + return reinterpret_cast<File*>(ptr)->seek(offset, whence); +} + jni::GlobalRef<jclass> g_NativeSambaClass(nullptr, nullptr); jmethodID g_CreateDirEntry; @@ -244,12 +379,18 @@ void RegisterSamba(JNIEnv* env) { { "nativeContextRemoveDir", "(JLjava/lang/String;)Z", reinterpret_cast<void*>(&nativeContextRemoveDir) }, { "nativeContextUnlink", "(JLjava/lang/String;)Z", reinterpret_cast<void*>(&nativeContextUnlink) }, { "nativeContextReadLink", "(JLjava/lang/String;)Ljava/lang/String;", reinterpret_cast<void*>(&nativeContextReadLink) }, + { "nativeContextOpenFile", "(JLjava/lang/String;I)J", reinterpret_cast<void*>(&nativeContextOpenFile) }, { "nativeUrlDestroy", "(J)V", reinterpret_cast<void*>(&nativeUrlDestroy) }, { "nativeUrlPath", "(J)Ljava/lang/String;", reinterpret_cast<void*>(&nativeUrlPath) }, { "nativeDirDestroy", "(J)V", reinterpret_cast<void*>(&nativeDirDestroy) }, { "nativeDirList", "(J)[Lorg/the_jk/cleversync/io/samba/NativeSamba$DirEntry;", reinterpret_cast<void*>(&nativeDirList) }, + + { "nativeFileDestroy", "(J)V", reinterpret_cast<void*>(&nativeFileDestroy) }, + { "nativeFileRead", "(J[BII)I", reinterpret_cast<void*>(&nativeFileRead) }, + { "nativeFileSeek", "(JJI)J", reinterpret_cast<void*>(&nativeFileSeek) }, + { "nativeFileWrite", "(J[BII)I", reinterpret_cast<void*>(&nativeFileWrite) }, }; auto ret = env->RegisterNatives(clazz.get(), methods, sizeof(methods) / sizeof(methods[0])); ABORT_IF_NOT_OK(ret); diff --git a/libs/samba/src/main/java/org/the_jk/cleversync/io/samba/NativeSamba.kt b/libs/samba/src/main/java/org/the_jk/cleversync/io/samba/NativeSamba.kt index 1863e22..c951d37 100644 --- a/libs/samba/src/main/java/org/the_jk/cleversync/io/samba/NativeSamba.kt +++ b/libs/samba/src/main/java/org/the_jk/cleversync/io/samba/NativeSamba.kt @@ -27,6 +27,7 @@ internal object NativeSamba { fun removeDir(path: String): Boolean fun unlink(path: String): Boolean fun readLink(path: String): String? + fun openFile(path: String, mode: OpenMode): File? } interface Url : Object { @@ -53,6 +54,25 @@ internal object NativeSamba { fun list(): Array<DirEntry> } + enum class OpenMode(val value: Int) { + READ(0), + WRITE_CREATE_TRUNCATE(1), + } + + enum class SeekWhence(val value: Int) { + SET(0), + CURRENT(1), + } + + interface File : Object { + val path: String + + fun read(bytes: ByteArray, offset: Int, length: Int): Int + fun seek(offset: Long, whence: SeekWhence): Long + + fun write(bytes: ByteArray, offset: Int, length: Int): Int + } + private class NativeContext(private var ptr: Long): Context { override fun destroy() { if (ptr == 0L) return @@ -97,6 +117,11 @@ internal object NativeSamba { override fun readLink(path: String): String? { return nativeContextReadLink(ptr, path) } + + override fun openFile(path: String, mode: NativeSamba.OpenMode): File? { + val file = nativeContextOpenFile(ptr, path, mode.value) + return if (file != 0L) NativeFile(path, file) else null + } } private class NativeUrl(private var ptr: Long): Url { @@ -123,6 +148,26 @@ internal object NativeSamba { } } + private class NativeFile(override val path: String, private var ptr: Long): File { + override fun read(bytes: ByteArray, offset: Int, length: Int): Int { + return nativeFileRead(ptr, bytes, offset, length) + } + + override fun seek(offset: Long, whence: SeekWhence): Long { + return nativeFileSeek(ptr, offset, whence.value) + } + + override fun write(bytes: ByteArray, offset: Int, length: Int): Int { + return nativeFileWrite(ptr, bytes, offset, length) + } + + override fun destroy() { + if (ptr == 0L) return + nativeFileDestroy(ptr) + ptr = 0L + } + } + init { System.loadLibrary("samba") } @@ -154,10 +199,16 @@ internal object NativeSamba { private external fun nativeContextRemoveDir(ptr: Long, path: String): Boolean private external fun nativeContextUnlink(ptr: Long, path: String): Boolean private external fun nativeContextReadLink(otr: Long, path: String): String? + private external fun nativeContextOpenFile(ptr: Long, path: String, mode: Int): Long private external fun nativeUrlDestroy(ptr: Long) private external fun nativeUrlPath(ptr: Long): String private external fun nativeDirDestroy(ptr: Long) private external fun nativeDirList(ptr: Long): Array<DirEntry> + + private external fun nativeFileDestroy(ptr: Long) + private external fun nativeFileRead(ptr: Long, bytes: ByteArray, offset: Int, length: Int): Int + private external fun nativeFileSeek(ptr: Long, offset: Long, whence: Int): Long + private external fun nativeFileWrite(ptr: Long, bytes: ByteArray, offset: Int, length: Int): Int } diff --git a/libs/samba/src/main/java/org/the_jk/cleversync/io/samba/SambaConnection.kt b/libs/samba/src/main/java/org/the_jk/cleversync/io/samba/SambaConnection.kt index d7302cd..6eda092 100644 --- a/libs/samba/src/main/java/org/the_jk/cleversync/io/samba/SambaConnection.kt +++ b/libs/samba/src/main/java/org/the_jk/cleversync/io/samba/SambaConnection.kt @@ -40,6 +40,9 @@ internal class SambaConnection(uri: String, credentials: SambaCredentials) { fun readLink(path: String): String? = if (connected) context.readLink(join(url!!.path(), path)) else null + fun openFile(path: String, read: NativeSamba.OpenMode): NativeSamba.File? = + if (connected) context.openFile(join(url!!.path(), path), read) else null + companion object { fun join(a: String, b: String): String { if (a.isEmpty() || b.startsWith("/")) return b diff --git a/libs/samba/src/main/java/org/the_jk/cleversync/io/samba/SambaFile.kt b/libs/samba/src/main/java/org/the_jk/cleversync/io/samba/SambaFile.kt index 817c1bf..9913035 100644 --- a/libs/samba/src/main/java/org/the_jk/cleversync/io/samba/SambaFile.kt +++ b/libs/samba/src/main/java/org/the_jk/cleversync/io/samba/SambaFile.kt @@ -15,11 +15,96 @@ internal class SambaFile( private var cacheEndOfLife: Instant = Instant.now().plusSeconds(60), ) : ModifiableFile { override fun write(): OutputStream { - TODO("Not yet implemented") + val file = conn.openFile(path, NativeSamba.OpenMode.WRITE_CREATE_TRUNCATE) + ?: throw IOException(conn.error) + return object : OutputStream() { + override fun write(b: Int) { + val buffer = ByteArray(1) + buffer[0] = b.toByte() + if (file.write(buffer, 0, 1) != 1) throw IOException(conn.error) + } + + override fun write(b: ByteArray?) = write(b, 0, b?.size ?: 0) + + override fun write(b: ByteArray?, off: Int, len: Int) { + if (b == null) throw NullPointerException("b == null") + if (off < 0) throw IndexOutOfBoundsException("off < 0") + if (len < 0) throw java.lang.IndexOutOfBoundsException("len < 0") + if (off + len > b.size) throw IndexOutOfBoundsException("off + len > b.size") + if (file.write(b, off, len) != len) throw IOException(conn.error) + } + + override fun flush() { + clearCache() + } + + override fun close() { + file.destroy() + clearCache() + } + } } override fun read(): InputStream { - TODO("Not yet implemented") + val file = conn.openFile(path, NativeSamba.OpenMode.READ) ?: throw IOException(conn.error) + return object : InputStream() { + private var markedPosition = 0L + + override fun read(): Int { + val buffer = ByteArray(1) + val got = file.read(buffer, 0, 1) + if (got == 0) return -1 + if (got < 0) throw IOException(conn.error) + return buffer[0].toInt() + } + + override fun read(b: ByteArray?) = read(b, 0, b?.size ?: 0) + + override fun read(b: ByteArray?, off: Int, len: Int): Int { + if (b == null) throw NullPointerException("b == null") + if (off < 0) throw IndexOutOfBoundsException("off < 0") + if (len < 0) throw java.lang.IndexOutOfBoundsException("len < 0") + if (off + len > b.size) throw IndexOutOfBoundsException("off + len > b.size") + if (len == 0) return 0 + val got = file.read(b, off, len) + if (got == 0) return -1 + if (got < 0) throw IOException(conn.error) + return got + } + + override fun skip(n: Long): Long { + if (n <= 0) return 0 + val offset = file.seek(n, NativeSamba.SeekWhence.CURRENT) + if (offset < 0) throw IOException(conn.error) + return offset - n + } + + override fun available(): Int { + val current = file.seek(0, NativeSamba.SeekWhence.CURRENT) + if (current < 0) throw IOException(conn.error) + val total = size + if (current.toULong() >= total) return 0 + val left = total - current.toULong() + if (left >= Int.MAX_VALUE.toULong()) return Int.MAX_VALUE + return left.toInt() + } + + override fun close() { + file.destroy() + } + + override fun mark(readlimit: Int) { + markedPosition = file.seek(0, NativeSamba.SeekWhence.CURRENT) + } + + override fun markSupported() = true + + override fun reset() { + if (file.seek(markedPosition, NativeSamba.SeekWhence.SET) != markedPosition) { + throw IOException(conn.error) + } + } + } } override val size: ULong get() { @@ -37,4 +122,8 @@ internal class SambaFile( private fun useCached(): Boolean { return Instant.now().isBefore(cacheEndOfLife) } + + private fun clearCache() { + cacheEndOfLife = Instant.EPOCH + } } diff --git a/libs/samba/src/test/java/org/the_jk/cleversync/samba/SambaTreeTest.kt b/libs/samba/src/test/java/org/the_jk/cleversync/samba/SambaTreeTest.kt index 1eddb6b..5df61da 100644 --- a/libs/samba/src/test/java/org/the_jk/cleversync/samba/SambaTreeTest.kt +++ b/libs/samba/src/test/java/org/the_jk/cleversync/samba/SambaTreeTest.kt @@ -14,6 +14,7 @@ import org.robolectric.annotation.Config import org.the_jk.cleversync.io.Link import org.the_jk.cleversync.io.samba.SambaCredentials import java.io.File +import java.nio.charset.StandardCharsets import java.nio.file.Files import kotlin.io.path.createSymbolicLinkPointingTo @@ -82,6 +83,73 @@ class SambaTreeTest { } } + @Test + fun readFile() { + File(shareDir, "file").writeText("hello world") + + SambaTreeFactory.tree(uri, credentials).getOrThrow().use { root -> + val file = root.openFile("file") + assertThat(file?.name).isEqualTo("file") + assertThat(file?.size).isEqualTo(11UL) + + file?.read().use { input -> + assertThat(input?.readAllBytes()?.toString(StandardCharsets.UTF_8)).isEqualTo("hello world") + } + + file?.read().use { input -> + assertThat(input?.available()).isEqualTo(11) + assertThat(input?.markSupported()).isTrue() + val buffer = ByteArray(10) + assertThat(input?.read(buffer, 5, 5)).isEqualTo(5) + input?.mark(100) + assertThat(buffer.sliceArray(5..<10).toString(StandardCharsets.UTF_8)).isEqualTo("hello") + assertThat(input?.read(buffer)).isEqualTo(6) + assertThat(buffer.sliceArray(0..<6).toString(StandardCharsets.UTF_8)).isEqualTo(" world") + input?.reset() + assertThat(input?.read(buffer, 3, 5)).isEqualTo(5) + assertThat(buffer.sliceArray(3..<8).toString(StandardCharsets.UTF_8)).isEqualTo(" worl") + } + } + } + + @Test + fun writeFile() { + SambaTreeFactory.modifiableTree(uri, credentials).getOrThrow().use { root -> + val file = root.createFile("file") + assertThat(file.name).isEqualTo("file") + + file.write().writer().use { output -> + output.write("hello world") + } + + assertThat(file.size).isEqualTo(11UL) + } + + assertThat(File(shareDir, "file").readText()).isEqualTo("hello world") + } + + @Test + fun overwriteFile() { + File(shareDir, "file").writeText("hello world") + + SambaTreeFactory.modifiableTree(uri, credentials).getOrThrow().use { root -> + val file = root.modifiableOpenFile("file") + assertThat(file?.name).isEqualTo("file") + assertThat(file?.size).isEqualTo(11UL) + + file?.write().use { output -> + val buffer = "foobar".toByteArray(StandardCharsets.UTF_8) + output?.write(buffer, 0, 1) + output?.write(buffer, 1, 2) + output?.write(buffer, 3, 3) + } + + assertThat(file?.size).isEqualTo(6UL) + } + + assertThat(File(shareDir, "file").readText()).isEqualTo("foobar") + } + companion object { private lateinit var uri: String private lateinit var credentials: SambaCredentials |
