android_mtdutils: Add mtd_write_descriptor

This CL adds mtd_write_descriptor, which creates a write context from an
existing file descriptor. It also refactors the write path, notably
mtd_write_partition and write_block, so that a write context no longer
needs an MtdPartition.

BUG=chromium:426742
TEST=emerge android_mtdutils

Change-Id: I450f21ab6ae3dd774cb531f9b166dd1864b83d05
Reviewed-on: https://chromium-review.googlesource.com/234457
Reviewed-by: Mike Frysinger <vapier@chromium.org>
Commit-Queue: Nam Nguyen <namnguyen@chromium.org>
Trybot-Ready: Nam Nguyen <namnguyen@chromium.org>
Tested-by: Nam Nguyen <namnguyen@chromium.org>
diff --git a/mtdutils/mtdutils.c b/mtdutils/mtdutils.c
index 9598c35..4e0463b 100644
--- a/mtdutils/mtdutils.c
+++ b/mtdutils/mtdutils.c
@@ -48,10 +48,13 @@
 };
 
 struct MtdWriteContext {
-    const MtdPartition *partition;
+    uint64_t size;
+    uint32_t erase_size;
     char *buffer;
     size_t stored;
     int fd;
+    /* true if |fd| was opened by write_partition, and should be closed. */
+    bool owned_fd;
 
     off_t* bad_block_offsets;
     int bad_block_alloc;
@@ -455,32 +458,45 @@
     free(ctx);
 }
 
-MtdWriteContext *mtd_write_partition(const MtdPartition *partition)
+MtdWriteContext *mtd_write_descriptor(int fd, const char *dev_node_name)
 {
     MtdWriteContext *ctx = (MtdWriteContext*) malloc(sizeof(MtdWriteContext));
     if (ctx == NULL) return NULL;
 
+    if (mtd_node_info(dev_node_name, &ctx->size, &ctx->erase_size, NULL) != 0) {  /* fills ctx->size and ctx->erase_size */
+        free(ctx);
+        return NULL;
+    }
+
     ctx->bad_block_offsets = NULL;
     ctx->bad_block_alloc = 0;
     ctx->bad_block_count = 0;
 
-    ctx->buffer = malloc(partition->erase_size);
+    ctx->buffer = malloc(ctx->erase_size);
     if (ctx->buffer == NULL) {
         free(ctx);
         return NULL;
     }
 
+    ctx->fd = fd;  /* Borrowed: the caller keeps ownership of |fd|. */
+    ctx->owned_fd = false;  /* So mtd_write_close() leaves |fd| open. */
+    ctx->stored = 0;
+    return ctx;
+}
+
+MtdWriteContext *mtd_write_partition(const MtdPartition *partition)
+{
     char mtddevname[32];
     sprintf(mtddevname, "/dev/mtd/mtd%d", partition->device_index);
-    ctx->fd = open(mtddevname, O_RDWR);
-    if (ctx->fd < 0) {
-        free(ctx->buffer);
-        free(ctx);
+    int fd = open(mtddevname, O_RDWR | O_CLOEXEC);
+    if (fd < 0) {
         return NULL;
     }
 
-    ctx->partition = partition;
-    ctx->stored = 0;
+    /* On success the context owns |fd|; on failure close it (keeping errno) so it isn't leaked. */
+    MtdWriteContext *ctx = mtd_write_descriptor(fd, mtddevname);
+    if (ctx != NULL) ctx->owned_fd = true;
+    else { int saved_errno = errno; close(fd); errno = saved_errno; }
     return ctx;
 }
 
@@ -495,14 +511,13 @@
 
 static int write_block(MtdWriteContext *ctx, const char *data)
 {
-    const MtdPartition *partition = ctx->partition;
     int fd = ctx->fd;
 
     off_t pos = lseek(fd, 0, SEEK_CUR);
     if (pos == (off_t) -1) return 1;
 
-    ssize_t size = partition->erase_size;
-    while (pos + size <= (int) partition->size) {
+    ssize_t size = ctx->erase_size;
+    while (pos + size <= ctx->size) {
         loff_t bpos = pos;
         int ret = ioctl(fd, MEMGETBADBLOCK, &bpos);
         if (ret != 0 && !(ret == -1 && errno == EOPNOTSUPP)) {
@@ -510,7 +525,7 @@
             fprintf(stderr,
                     "mtd: not writing bad block at 0x%08lx (ret %d errno %d)\n",
                     pos, ret, errno);
-            pos += partition->erase_size;
+            pos += ctx->erase_size;
             continue;  // Don't try to erase known factory-bad blocks.
         }
 
@@ -554,7 +569,7 @@
         add_bad_block_offset(ctx, pos);
         printf("mtd: skipping write block at 0x%08lx\n", pos);
         ioctl(fd, MEMERASE, &erase_info);
-        pos += partition->erase_size;
+        pos += ctx->erase_size;
     }
 
     // Ran out of space on the device
@@ -567,8 +582,8 @@
     size_t wrote = 0;
     while (wrote < len) {
         // Coalesce partial writes into complete blocks
-        if (ctx->stored > 0 || len - wrote < ctx->partition->erase_size) {
-            size_t avail = ctx->partition->erase_size - ctx->stored;
+        if (ctx->stored > 0 || len - wrote < ctx->erase_size) {
+            size_t avail = ctx->erase_size - ctx->stored;
             size_t copy = len - wrote < avail ? len - wrote : avail;
             memcpy(ctx->buffer + ctx->stored, data + wrote, copy);
             ctx->stored += copy;
@@ -576,15 +591,15 @@
         }
 
         // If a complete block was accumulated, write it
-        if (ctx->stored == ctx->partition->erase_size) {
+        if (ctx->stored == ctx->erase_size) {
             if (write_block(ctx, ctx->buffer)) return -1;
             ctx->stored = 0;
         }
 
         // Write complete blocks directly from the user's buffer
-        while (ctx->stored == 0 && len - wrote >= ctx->partition->erase_size) {
+        while (ctx->stored == 0 && len - wrote >= ctx->erase_size) {
             if (write_block(ctx, data + wrote)) return -1;
-            wrote += ctx->partition->erase_size;
+            wrote += ctx->erase_size;
         }
     }
 
@@ -595,7 +610,7 @@
 {
     // Zero-pad and write any pending data to get us to a block boundary
     if (ctx->stored > 0) {
-        size_t zero = ctx->partition->erase_size - ctx->stored;
+        size_t zero = ctx->erase_size - ctx->stored;
         memset(ctx->buffer + ctx->stored, 0, zero);
         if (write_block(ctx, ctx->buffer)) return -1;
         ctx->stored = 0;
@@ -604,7 +619,7 @@
     off_t pos = lseek(ctx->fd, 0, SEEK_CUR);
     if ((off_t) pos == (off_t) -1) return pos;
 
-    const int total = (ctx->partition->size - pos) / ctx->partition->erase_size;
+    const int total = (ctx->size - pos) / ctx->erase_size;
     if (blocks < 0) blocks = total;
     if (blocks > total) {
         errno = ENOSPC;
@@ -616,17 +631,17 @@
         loff_t bpos = pos;
         if (ioctl(ctx->fd, MEMGETBADBLOCK, &bpos) > 0) {
             printf("mtd: not erasing bad block at 0x%08lx\n", pos);
-            pos += ctx->partition->erase_size;
+            pos += ctx->erase_size;
             continue;  // Don't try to erase known factory-bad blocks.
         }
 
         struct erase_info_user erase_info;
         erase_info.start = pos;
-        erase_info.length = ctx->partition->erase_size;
+        erase_info.length = ctx->erase_size;
         if (ioctl(ctx->fd, MEMERASE, &erase_info) < 0) {
             printf("mtd: erase failure at 0x%08lx\n", pos);
         }
-        pos += ctx->partition->erase_size;
+        pos += ctx->erase_size;
     }
 
     return pos;
@@ -637,7 +652,7 @@
     int r = 0;
     // Make sure any pending data gets written
     if (mtd_erase_blocks(ctx, 0) == (off_t) -1) r = -1;
-    if (close(ctx->fd)) r = -1;
+    if (ctx->owned_fd && close(ctx->fd)) r = -1;
     free(ctx->bad_block_offsets);
     free(ctx->buffer);
     free(ctx);
@@ -651,7 +666,7 @@
     int i;
     for (i = 0; i < ctx->bad_block_count; ++i) {
         if (ctx->bad_block_offsets[i] == pos) {
-            pos += ctx->partition->erase_size;
+            pos += ctx->erase_size;
         } else if (ctx->bad_block_offsets[i] > pos) {
             return pos;
         }
diff --git a/mtdutils/mtdutils.h b/mtdutils/mtdutils.h
index c8a859a..77c55d6 100644
--- a/mtdutils/mtdutils.h
+++ b/mtdutils/mtdutils.h
@@ -64,7 +64,10 @@
 /* NOTICE: Only use with reads of multiple erase blocks please. */
 off64_t mtd_read_skip_to(const MtdReadContext *, off64_t offset);
 
+/* mtd_write_partition opens the device node itself; mtd_write_close() closes it. */
 MtdWriteContext *mtd_write_partition(const MtdPartition *);
+/* mtd_write_descriptor borrows |fd| (caller must close it); size/erase size are looked up via |dev_node_name|. */
+MtdWriteContext *mtd_write_descriptor(int fd, const char *dev_node_name);
 ssize_t mtd_write_data(MtdWriteContext *, const char *data, size_t data_len);
 off_t mtd_erase_blocks(MtdWriteContext *, int blocks);  /* 0 ok, -1 for all */
 off_t mtd_find_write_start(MtdWriteContext *ctx, off_t pos);