Update a MySQL table using DataFrames and JDBC

dzjeubhm · asked 2021-06-20 · in MySQL

I am trying to insert and update some data in MySQL using Spark SQL DataFrames and a JDBC connection.

I have succeeded in inserting new data with SaveMode.Append. Is there a way to update data that already exists in the MySQL table from Spark SQL?

My code to insert is:

`myDataFrame.write.mode(SaveMode.Append).jdbc(JDBCurl, mySqlTable, connectionProperties)`

If I switch to SaveMode.Overwrite it deletes the whole table and creates a new one; I am looking for something like MySQL's `ON DUPLICATE KEY UPDATE`.

z9smfwbn 1#

I was not able to do this in PySpark, so I decided to use ODBC instead.

```python
import pyodbc

url = "jdbc:sqlserver://xxx:1433;databaseName=xxx;user=xxx;password=xxx"
# Stage the DataFrame into a scratch table, then upsert from it with a direct connection.
df.write.jdbc(url=url, table="__TableInsert", mode='overwrite')

cnxn = pyodbc.connect('Driver={ODBC Driver 17 for SQL Server};Server=xxx;Database=xxx;Uid=xxx;Pwd=xxx;', autocommit=False)
try:
    crsr = cnxn.cursor()
    # DO UPSERTS OR WHATEVER YOU WANT
    crsr.execute("DELETE FROM Table")
    crsr.execute("INSERT INTO Table (Field) SELECT Field FROM __TableInsert")
    cnxn.commit()
except:
    cnxn.rollback()
cnxn.close()
```
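
For the MySQL case from the question, the same staging-table pattern can be expressed with plain JDBC and MySQL's `ON DUPLICATE KEY UPDATE`. The following is only a minimal sketch: `df` stands for the DataFrame to be upserted, and the table names (`mytable`, `mytable_staging`), columns (`id`, `value`) and connection details are illustrative assumptions, not part of this answer.

```scala
import java.sql.DriverManager
import java.util.Properties

// Illustrative connection details -- replace with your own.
val url = "jdbc:mysql://host:3306/mydb?user=xxx&password=xxx"
val props = new Properties()

// 1) Stage the DataFrame into a scratch table that is rebuilt on every run.
df.write.mode("overwrite").jdbc(url, "mytable_staging", props)

// 2) Upsert from the staging table into the target table, in one statement on the driver.
val conn = DriverManager.getConnection(url)
try {
  conn.createStatement().executeUpdate(
    """INSERT INTO mytable (id, value)
      |SELECT id, value FROM mytable_staging
      |ON DUPLICATE KEY UPDATE value = VALUES(value)""".stripMargin)
} finally {
  conn.close()
}
```

Because the upsert runs as a single statement in the database, it avoids the per-partition transaction issues discussed in the other answers, at the cost of writing the full DataFrame twice.
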
ttygqcqt 2#

zero323's answer is right; I just want to add that you can use the jaydebeapi package as a workaround to update data in your MySQL table: https://pypi.python.org/pypi/jaydebeapi/ Since you already have the MySQL JDBC driver installed, this may be low-hanging fruit.

The jaydebeapi module allows you to connect from Python code to databases using Java JDBC. It provides a Python DB-API v2.0 for that database.

We use the Anaconda distribution of Python, and the jaydebeapi Python package comes standard.

See the examples at the link above.

ss2ws0br 3#

It's a pity that there is no SaveMode.Upsert mode in Spark for such quite common cases as upserting.

zero323 is right in general, but I think it should be possible (with a compromise in performance) to offer such a replace feature.

I also wanted to provide some Java code for this case. Of course, it is not as performant as Spark's built-in writer, but it should be a good basis for your requirements. Just modify it to your needs:

```java
import java.sql.*;
import java.util.Arrays;
import java.util.Iterator;
import org.apache.spark.sql.Row;

myDF = myDF.repartition(20); // one connection per partition, see below (repartition returns a new Dataset)

myDF.foreachPartition((Iterator<Row> t) -> {
            Connection conn = DriverManager.getConnection(
                    Constants.DB_JDBC_CONN,
                    Constants.DB_JDBC_USER,
                    Constants.DB_JDBC_PASS);

            conn.setAutoCommit(true);
            Statement statement = conn.createStatement();

            final int batchSize = 100000;
            int i = 0;
            while (t.hasNext()) {
                Row row = t.next();
                try {
                    // better than REPLACE INTO, less cycles
                    statement.addBatch(("INSERT INTO mytable " + "VALUES ("
                            + "'" + row.getAs("_id") + "', "
                            + "'" + row.getStruct(1).get(0) + "'"
                            + ") ON DUPLICATE KEY UPDATE _id='" + row.getAs("_id") + "';"));
                    //conn.commit();

                    if (++i % batchSize == 0) {
                        statement.executeBatch();
                    }
                } catch (SQLIntegrityConstraintViolationException e) {
                    //should not occur, nevertheless
                    //conn.commit();
                } catch (SQLException e) {
                    e.printStackTrace();
                } finally {
                    //conn.commit();
                    statement.executeBatch();
                }
            }
            int[] ret = statement.executeBatch();

            System.out.println("Ret val: " + Arrays.toString(ret));
            System.out.println("Update count: " + statement.getUpdateCount());
            conn.commit();

            statement.close();
            conn.close();
});
```
deyfvvtc 4#

If your table is small, you can read the existing SQL data, do the upsert in a Spark DataFrame, and then overwrite the existing SQL table.
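
As an illustration of this idea, here is a minimal sketch. It assumes a key column `id`, a target table `mytable`, an `updates` DataFrame holding the new rows with the same schema as the table, and `spark`/`jdbcUrl`/`connectionProperties` already defined; none of these names come from the answer itself.

```scala
import org.apache.spark.sql.SaveMode

// Read the current contents of the (small) table.
val existing = spark.read.jdbc(jdbcUrl, "mytable", connectionProperties)

// New/changed rows win; keep only the existing rows whose key is not being updated.
val merged = updates.union(existing.join(updates, Seq("id"), "left_anti"))

// Force evaluation before overwriting: SaveMode.Overwrite drops the table,
// so the lazy read must have happened (and been cached) before the write starts.
merged.cache()
merged.count()

merged.write.mode(SaveMode.Overwrite).jdbc(jdbcUrl, "mytable", connectionProperties)
```
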

ttcibm8c 5#

It is not possible. As for now (Spark 1.6.0 / 2.2.0 SNAPSHOT), the Spark DataFrameWriter supports only four writing modes:

- SaveMode.Overwrite: overwrite the existing data.
- SaveMode.Append: append the data.
- SaveMode.Ignore: ignore the operation (i.e. no-op).
- SaveMode.ErrorIfExists: the default option, throw an exception at runtime.

You can insert manually, for example using mapPartitions (since you want the upsert operation to be idempotent and as such easy to implement), write to a temporary table and execute the upsert manually, or use triggers; see the sketch below for the per-partition variant.

In general, implementing upsert behaviour for batch operations while keeping decent performance is far from trivial. You have to remember that, in general, there will be multiple concurrent transactions in place (one per partition), so you have to ensure there are no write conflicts (typically by using application-specific partitioning) or provide an appropriate recovery procedure. In practice, it may be better to perform batched writes to a temporary table and resolve the upsert part directly in the database.
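
A minimal Scala sketch of the "manual upsert per partition" idea mentioned above, assuming a MySQL table `mytable(id, value)` with a primary key on `id`; the table, column names and the `jdbcUrl`/`user`/`password` variables are illustrative assumptions.

```scala
import java.sql.DriverManager

df.rdd.foreachPartition { rows =>
  // One connection and one prepared statement per partition.
  val conn = DriverManager.getConnection(jdbcUrl, user, password)
  val stmt = conn.prepareStatement(
    "INSERT INTO mytable (id, value) VALUES (?, ?) " +
      "ON DUPLICATE KEY UPDATE value = VALUES(value)")
  try {
    rows.foreach { row =>
      stmt.setString(1, row.getAs[String]("id"))
      stmt.setString(2, row.getAs[String]("value"))
      stmt.addBatch()
    }
    stmt.executeBatch() // one round trip per partition; chunk this for very large partitions
  } finally {
    stmt.close()
    conn.close()
  }
}
```

Because the statement is parameterised, this also avoids the quoting and SQL-injection pitfalls of building the statement by string concatenation, as in the Java answer above.
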

35g0bw71 6#

Overwrite `org.apache.spark.sql.execution.datasources.jdbc.JdbcUtils.scala`, changing `insert into` to `replace into`:

```scala
import java.sql.{Connection, Driver, DriverManager, PreparedStatement, ResultSet, SQLException}

import scala.collection.JavaConverters._
import scala.util.control.NonFatal
import com.typesafe.scalalogging.Logger
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.execution.datasources.jdbc.{DriverRegistry, DriverWrapper, JDBCOptions}
import org.apache.spark.sql.jdbc.{JdbcDialect, JdbcDialects, JdbcType}
import org.apache.spark.sql.types._
import org.apache.spark.sql.{DataFrame, Row}

/**
 * Util functions for JDBC tables.
 */
object UpdateJdbcUtils {

val logger = Logger(this.getClass)

/**
* Returns a factory for creating connections to the given JDBC URL.
*
* @param options - JDBC options that contains url, table and other information.
*/
def createConnectionFactory(options: JDBCOptions): () => Connection = {
val driverClass: String = options.driverClass
() => {
DriverRegistry.register(driverClass)
val driver: Driver = DriverManager.getDrivers.asScala.collectFirst {
case d: DriverWrapper if d.wrapped.getClass.getCanonicalName == driverClass => d
case d if d.getClass.getCanonicalName == driverClass => d
}.getOrElse {
throw new IllegalStateException(
s"Did not find registered driver with class $driverClass")
}
driver.connect(options.url, options.asConnectionProperties)
}
}

/**
* Returns a PreparedStatement that inserts a row into table via conn.
*/
def insertStatement(conn: Connection, table: String, rddSchema: StructType, dialect: JdbcDialect)
: PreparedStatement = {
val columns = rddSchema.fields.map(x => dialect.quoteIdentifier(x.name)).mkString(",")
val placeholders = rddSchema.fields.map(_ => "?").mkString(",")
val sql = s"REPLACE INTO $table ($columns) VALUES ($placeholders)"
conn.prepareStatement(sql)
}

/**
* Retrieve standard jdbc types.
*
* @param dt The datatype (e.g. org.apache.spark.sql.types.StringType)
* @return The default JdbcType for this DataType
*/
def getCommonJDBCType(dt: DataType): Option[JdbcType] = {
dt match {
case IntegerType => Option(JdbcType("INTEGER", java.sql.Types.INTEGER))
case LongType => Option(JdbcType("BIGINT", java.sql.Types.BIGINT))
case DoubleType => Option(JdbcType("DOUBLE PRECISION", java.sql.Types.DOUBLE))
case FloatType => Option(JdbcType("REAL", java.sql.Types.FLOAT))
case ShortType => Option(JdbcType("INTEGER", java.sql.Types.SMALLINT))
case ByteType => Option(JdbcType("BYTE", java.sql.Types.TINYINT))
case BooleanType => Option(JdbcType("BIT(1)", java.sql.Types.BIT))
case StringType => Option(JdbcType("TEXT", java.sql.Types.CLOB))
case BinaryType => Option(JdbcType("BLOB", java.sql.Types.BLOB))
case TimestampType => Option(JdbcType("TIMESTAMP", java.sql.Types.TIMESTAMP))
case DateType => Option(JdbcType("DATE", java.sql.Types.DATE))
case t: DecimalType => Option(
JdbcType(s"DECIMAL(${t.precision},${t.scale})", java.sql.Types.DECIMAL))
case _ => None
}
}

private def getJdbcType(dt: DataType, dialect: JdbcDialect): JdbcType = {
dialect.getJDBCType(dt).orElse(getCommonJDBCType(dt)).getOrElse(
throw new IllegalArgumentException(s"Can't get JDBC type for ${dt.simpleString}"))
}

// A JDBCValueGetter is responsible for getting a value from ResultSet into a field
// for MutableRow. The last argument Int means the index for the value to be set in
// the row and also used for the value in ResultSet.
private type JDBCValueGetter = (ResultSet, InternalRow, Int) => Unit

// A JDBCValueSetter is responsible for setting a value from Row into a field for
// PreparedStatement. The last argument Int means the index for the value to be set
// in the SQL statement and also used for the value in Row.
private type JDBCValueSetter = (PreparedStatement, Row, Int) => Unit

/**
* Saves a partition of a DataFrame to the JDBC database. This is done in
* a single database transaction (unless isolation level is "NONE")
* in order to avoid repeatedly inserting data as much as possible.
*
* It is still theoretically possible for rows in a DataFrame to be
* inserted into the database more than once if a stage somehow fails after
* the commit occurs but before the stage can return successfully.
*
* This is not a closure inside saveTable() because apparently cosmetic
* implementation changes elsewhere might easily render such a closure
* non-Serializable. Instead, we explicitly close over all variables that
* are used.
*/
def savePartition(
getConnection: () => Connection,
table: String,
iterator: Iterator[Row],
rddSchema: StructType,
nullTypes: Array[Int],
batchSize: Int,
dialect: JdbcDialect,
isolationLevel: Int): Iterator[Byte] = {
val conn = getConnection()
var committed = false

var finalIsolationLevel = Connection.TRANSACTION_NONE
if (isolationLevel != Connection.TRANSACTION_NONE) {
  try {
    val metadata = conn.getMetaData
    if (metadata.supportsTransactions()) {
      // Update to at least use the default isolation, if any transaction level
      // has been chosen and transactions are supported
      val defaultIsolation = metadata.getDefaultTransactionIsolation
      finalIsolationLevel = defaultIsolation
      if (metadata.supportsTransactionIsolationLevel(isolationLevel)) {
        // Finally update to actually requested level if possible
        finalIsolationLevel = isolationLevel
      } else {
        logger.warn(s"Requested isolation level $isolationLevel is not supported; " +
          s"falling back to default isolation level $defaultIsolation")
      }
    } else {
      logger.warn(s"Requested isolation level $isolationLevel, but transactions are unsupported")
    }
  } catch {
    case NonFatal(e) => logger.warn("Exception while detecting transaction support", e)
  }
}
val supportsTransactions = finalIsolationLevel != Connection.TRANSACTION_NONE

try {
  if (supportsTransactions) {
    conn.setAutoCommit(false) // Everything in the same db transaction.
    conn.setTransactionIsolation(finalIsolationLevel)
  }
  val stmt = insertStatement(conn, table, rddSchema, dialect)
  val setters: Array[JDBCValueSetter] = rddSchema.fields.map(_.dataType)
    .map(makeSetter(conn, dialect, _))
  val numFields = rddSchema.fields.length

  try {
    var rowCount = 0
    while (iterator.hasNext) {
      val row = iterator.next()
      var i = 0
      while (i < numFields) {
        if (row.isNullAt(i)) {
          stmt.setNull(i + 1, nullTypes(i))
        } else {
          setters(i).apply(stmt, row, i)
        }
        i = i + 1
      }
      stmt.addBatch()
      rowCount += 1
      if (rowCount % batchSize == 0) {
        stmt.executeBatch()
        rowCount = 0
      }
    }
    if (rowCount > 0) {
      stmt.executeBatch()
    }
  } finally {
    stmt.close()
  }
  if (supportsTransactions) {
    conn.commit()
  }
  committed = true
  Iterator.empty
} catch {
  case e: SQLException =>
    val cause = e.getNextException
    if (cause != null && e.getCause != cause) {
      if (e.getCause == null) {
        e.initCause(cause)
      } else {
        e.addSuppressed(cause)
      }
    }
    throw e
} finally {
  if (!committed) {
    // The stage must fail.  We got here through an exception path, so
    // let the exception through unless rollback() or close() want to
    // tell the user about another problem.
    if (supportsTransactions) {
      conn.rollback()
    }
    conn.close()
  } else {
    // The stage must succeed.  We cannot propagate any exception close() might throw.
    try {
      conn.close()
    } catch {
      case e: Exception => logger.warn("Transaction succeeded, but closing failed", e)
    }
  }
}

}

/**
* Saves the RDD to the database in a single transaction.
*/
def saveTable(
df: DataFrame,
url: String,
table: String,
options: JDBCOptions) {
val dialect = JdbcDialects.get(url)
val nullTypes: Array[Int] = df.schema.fields.map { field =>
getJdbcType(field.dataType, dialect).jdbcNullType
}

val rddSchema = df.schema
val getConnection: () => Connection = createConnectionFactory(options)
val batchSize = options.batchSize
val isolationLevel = options.isolationLevel
df.foreachPartition(iterator => savePartition(
  getConnection, table, iterator, rddSchema, nullTypes, batchSize, dialect, isolationLevel)
)

}

private def makeSetter(
conn: Connection,
dialect: JdbcDialect,
dataType: DataType): JDBCValueSetter = dataType match {
case IntegerType =>
(stmt: PreparedStatement, row: Row, pos: Int) =>
stmt.setInt(pos + 1, row.getInt(pos))

case LongType =>
  (stmt: PreparedStatement, row: Row, pos: Int) =>
    stmt.setLong(pos + 1, row.getLong(pos))

case DoubleType =>
  (stmt: PreparedStatement, row: Row, pos: Int) =>
    stmt.setDouble(pos + 1, row.getDouble(pos))

case FloatType =>
  (stmt: PreparedStatement, row: Row, pos: Int) =>
    stmt.setFloat(pos + 1, row.getFloat(pos))

case ShortType =>
  (stmt: PreparedStatement, row: Row, pos: Int) =>
    stmt.setInt(pos + 1, row.getShort(pos))

case ByteType =>
  (stmt: PreparedStatement, row: Row, pos: Int) =>
    stmt.setInt(pos + 1, row.getByte(pos))

case BooleanType =>
  (stmt: PreparedStatement, row: Row, pos: Int) =>
    stmt.setBoolean(pos + 1, row.getBoolean(pos))

case StringType =>
  (stmt: PreparedStatement, row: Row, pos: Int) =>
    stmt.setString(pos + 1, row.getString(pos))

case BinaryType =>
  (stmt: PreparedStatement, row: Row, pos: Int) =>
    stmt.setBytes(pos + 1, row.getAs[Array[Byte]](pos))

case TimestampType =>
  (stmt: PreparedStatement, row: Row, pos: Int) =>
    stmt.setTimestamp(pos + 1, row.getAs[java.sql.Timestamp](pos))

case DateType =>
  (stmt: PreparedStatement, row: Row, pos: Int) =>
    stmt.setDate(pos + 1, row.getAs[java.sql.Date](pos))

case t: DecimalType =>
  (stmt: PreparedStatement, row: Row, pos: Int) =>
    stmt.setBigDecimal(pos + 1, row.getDecimal(pos))

case ArrayType(et, _) =>
  // remove type length parameters from end of type name
  val typeName = getJdbcType(et, dialect).databaseTypeDefinition
    .toLowerCase.split("\\(")(0)
  (stmt: PreparedStatement, row: Row, pos: Int) =>
    val array = conn.createArrayOf(
      typeName,
      row.getSeq[AnyRef](pos).toArray)
    stmt.setArray(pos + 1, array)

case _ =>
  (_: PreparedStatement, _: Row, pos: Int) =>
    throw new IllegalArgumentException(
      s"Can't translate non-null value for field $pos")

}
}

```

Usage:

```scala
val url = s"jdbc:mysql://$host/$database?useUnicode=true&characterEncoding=UTF-8"

val parameters: Map[String, String] = Map(
  "url" -> url,
  "dbtable" -> table,
  "driver" -> "com.mysql.jdbc.Driver",
  "numPartitions" -> numPartitions.toString,
  "user" -> user,
  "password" -> password
)
val options = new JDBCOptions(parameters)

for (d <- data) {
  UpdateJdbcUtils.saveTable(d, url, table, options)
}
```

PS: pay attention to deadlocks and do not update the data frequently; use this only for re-runs in emergency cases. I guess that is why Spark does not support it officially.
