Подтвердить что ты не робот

Изменение формы/поворот данных в Spark RDD и/или Spark DataFrames

У меня есть некоторые данные в следующем формате (RDD или Spark DataFrame):

from pyspark.sql import SQLContext
sqlContext = SQLContext(sc)

 rdd = sc.parallelize([('X01',41,'US',3),
                       ('X01',41,'UK',1),
                       ('X01',41,'CA',2),
                       ('X02',72,'US',4),
                       ('X02',72,'UK',6),
                       ('X02',72,'CA',7),
                       ('X02',72,'XX',8)])

# convert to a Spark DataFrame                    
schema = StructType([StructField('ID', StringType(), True),
                     StructField('Age', IntegerType(), True),
                     StructField('Country', StringType(), True),
                     StructField('Score', IntegerType(), True)])

df = sqlContext.createDataFrame(rdd, schema)

Что я хотел бы сделать, так это "переформатировать" данные, преобразовать определенные строки в стране (в частности, США, Великобритания и ЦС) в столбцы:

ID    Age  US  UK  CA  
'X01'  41  3   1   2  
'X02'  72  4   6   7   

По сути, мне нужно что-то в строках рабочего процесса Python pivot:

categories = ['US', 'UK', 'CA']
new_df = df[df['Country'].isin(categories)].pivot(index = 'ID', 
                                                  columns = 'Country',
                                                  values = 'Score')

Мой набор данных довольно большой, поэтому я действительно не могу collect() и глотать данные в памяти, чтобы выполнить преобразование в самом Python. Есть ли способ конвертировать Python .pivot() в вызывающую функцию при сопоставлении RDD или Spark DataFrame? Любая помощь будет оценена!

4b9b3361

Ответ 1

Начиная с Spark 1.6 вы можете использовать pivot функцию GroupedData и предоставить агрегированное выражение.

pivoted = (df
    .groupBy("ID", "Age")
    .pivot(
        "Country",
        ['US', 'UK', 'CA'])  # Optional list of levels
    .sum("Score"))  # alternatively you can use .agg(expr))
pivoted.show()

## +---+---+---+---+---+
## | ID|Age| US| UK| CA|
## +---+---+---+---+---+
## |X01| 41|  3|  1|  2|
## |X02| 72|  4|  6|  7|
## +---+---+---+---+---+

Уровни могут быть опущены, но при условии, что они могут повысить производительность и служить внутренним фильтром.

Этот метод по-прежнему относительно медленный, но, конечно, бит вручную передает данные вручную между JVM и Python.

Ответ 2

Прежде всего, это, вероятно, не очень хорошая идея, потому что вы не получаете никакой дополнительной информации, но вы привязываетесь к фиксированной схеме (то есть вам нужно знать, сколько стран вы ожидаете, и, конечно же, дополнительная страна означает изменение кода)

Сказав это, это проблема SQL, которая показана ниже. Но в случае, если вы предполагаете, что это не слишком "программное обеспечение" (серьезно, я это слышал!!), тогда вы можете отсылать первое решение.

Решение 1:

def reshape(t):
    out = []
    out.append(t[0])
    out.append(t[1])
    for v in brc.value:
        if t[2] == v:
            out.append(t[3])
        else:
            out.append(0)
    return (out[0],out[1]),(out[2],out[3],out[4],out[5])
def cntryFilter(t):
    if t[2] in brc.value:
        return t
    else:
        pass

def addtup(t1,t2):
    j=()
    for k,v in enumerate(t1):
        j=j+(t1[k]+t2[k],)
    return j

def seq(tIntrm,tNext):
    return addtup(tIntrm,tNext)

def comb(tP,tF):
    return addtup(tP,tF)


countries = ['CA', 'UK', 'US', 'XX']
brc = sc.broadcast(countries)
reshaped = calls.filter(cntryFilter).map(reshape)
pivot = reshaped.aggregateByKey((0,0,0,0),seq,comb,1)
for i in pivot.collect():
    print i

Теперь, решение 2: Конечно, лучше, поскольку SQL является правильным инструментом для этого

callRow = calls.map(lambda t:   

Row(userid=t[0],age=int(t[1]),country=t[2],nbrCalls=t[3]))
callsDF = ssc.createDataFrame(callRow)
callsDF.printSchema()
callsDF.registerTempTable("calls")
res = ssc.sql("select userid,age,max(ca),max(uk),max(us),max(xx)\
                    from (select userid,age,\
                                  case when country='CA' then nbrCalls else 0 end ca,\
                                  case when country='UK' then nbrCalls else 0 end uk,\
                                  case when country='US' then nbrCalls else 0 end us,\
                                  case when country='XX' then nbrCalls else 0 end xx \
                             from calls) x \
                     group by userid,age")
res.show()

настроено:

data=[('X01',41,'US',3),('X01',41,'UK',1),('X01',41,'CA',2),('X02',72,'US',4),('X02',72,'UK',6),('X02',72,'CA',7),('X02',72,'XX',8)]
 calls = sc.parallelize(data,1)
countries = ['CA', 'UK', 'US', 'XX']

Результат:

Из первого решения

(('X02', 72), (7, 6, 4, 8)) 
(('X01', 41), (2, 1, 3, 0))

Из второго решения:

root  |-- age: long (nullable = true)  
      |-- country: string (nullable = true)  
      |-- nbrCalls: long (nullable = true)  
      |-- userid: string (nullable = true)

userid age ca uk us xx 
 X02    72  7  6  4  8  
 X01    41  2  1  3  0

Скажите, пожалуйста, если это работает, или нет:)

Лучший Аян

Ответ 3

Здесь используется собственный подход Spark, который не затрудняет имена столбцов. Он основан на aggregateByKey и использует словарь для сбора столбцов, которые отображаются для каждого ключа. Затем мы собираем все имена столбцов, чтобы создать окончательный файл данных. [Предварительная версия использовала jsonRDD после испускания словаря для каждой записи, но это более эффективно.] Ограничение на конкретный список столбцов или исключение таких, как XX, было бы легкой модификацией.

Производительность кажется хорошей даже на довольно больших таблицах. Я использую вариацию, которая подсчитывает количество раз, каждое из которых имеет переменное число событий для каждого идентификатора, генерируя один столбец для каждого типа события. Код в основном тот же, за исключением того, что он использует коллекции .Counter вместо dict в seqFn для подсчета вхождений.

from pyspark.sql.types import *

rdd = sc.parallelize([('X01',41,'US',3),
                       ('X01',41,'UK',1),
                       ('X01',41,'CA',2),
                       ('X02',72,'US',4),
                       ('X02',72,'UK',6),
                       ('X02',72,'CA',7),
                       ('X02',72,'XX',8)])

schema = StructType([StructField('ID', StringType(), True),
                     StructField('Age', IntegerType(), True),
                     StructField('Country', StringType(), True),
                     StructField('Score', IntegerType(), True)])

df = sqlCtx.createDataFrame(rdd, schema)

def seqPivot(u, v):
    if not u:
        u = {}
    u[v.Country] = v.Score
    return u

def cmbPivot(u1, u2):
    u1.update(u2)
    return u1

pivot = (
    df
    .rdd
    .keyBy(lambda row: row.ID)
    .aggregateByKey(None, seqPivot, cmbPivot)
)
columns = (
    pivot
    .values()
    .map(lambda u: set(u.keys()))
    .reduce(lambda s,t: s.union(t))
)
result = sqlCtx.createDataFrame(
    pivot
    .map(lambda (k, u): [k] + [u.get(c) for c in columns]),
    schema=StructType(
        [StructField('ID', StringType())] + 
        [StructField(c, IntegerType()) for c in columns]
    )
)
result.show()

Выдает:

ID  CA UK US XX  
X02 7  6  4  8   
X01 2  1  3  null

Ответ 4

Итак, сначала я должен был внести эту коррекцию в ваш RDD (который соответствует вашему фактическому результату):

rdd = sc.parallelize([('X01',41,'US',3),
                      ('X01',41,'UK',1),
                      ('X01',41,'CA',2),
                      ('X02',72,'US',4),
                      ('X02',72,'UK',6),
                      ('X02',72,'CA',7),
                      ('X02',72,'XX',8)])

Как только я сделал это исправление, это сделало трюк:

df.select($"ID", $"Age").groupBy($"ID").agg($"ID", first($"Age") as "Age")
.join(
    df.select($"ID" as "usID", $"Country" as "C1",$"Score" as "US"),
    $"ID" === $"usID" and $"C1" === "US"
)
.join(
    df.select($"ID" as "ukID", $"Country" as "C2",$"Score" as "UK"),
    $"ID" === $"ukID" and $"C2" === "UK"
)
.join(
    df.select($"ID" as "caID", $"Country" as "C3",$"Score" as "CA"), 
    $"ID" === $"caID" and $"C3" === "CA"
)
.select($"ID",$"Age",$"US",$"UK",$"CA")

Не так элегантно, как ваш стержень, конечно.

Ответ 5

Просто некоторые комментарии к очень полезному ответу patricksurry:

  • Отсутствует столбец Age, поэтому просто добавьте u [ "Age" ] = v.Age к функции seqPivot
  • оказалось, что обе петли над элементами столбцов давали элементы в другом порядке. Значения столбцов были правильными, но не именами их. Чтобы избежать такого поведения, просто закажите список столбцов.

Вот немного модифицированный код:

from pyspark.sql.types import *

rdd = sc.parallelize([('X01',41,'US',3),
                       ('X01',41,'UK',1),
                       ('X01',41,'CA',2),
                       ('X02',72,'US',4),
                       ('X02',72,'UK',6),
                       ('X02',72,'CA',7),
                       ('X02',72,'XX',8)])

schema = StructType([StructField('ID', StringType(), True),
                     StructField('Age', IntegerType(), True),
                     StructField('Country', StringType(), True),
                     StructField('Score', IntegerType(), True)])

df = sqlCtx.createDataFrame(rdd, schema)

# u is a dictionarie
# v is a Row
def seqPivot(u, v):
    if not u:
        u = {}
    u[v.Country] = v.Score
    # In the original posting the Age column was not specified
    u["Age"] = v.Age
    return u

# u1
# u2
def cmbPivot(u1, u2):
    u1.update(u2)
    return u1

pivot = (
    rdd
    .map(lambda row: Row(ID=row[0], Age=row[1], Country=row[2],  Score=row[3]))
    .keyBy(lambda row: row.ID)
    .aggregateByKey(None, seqPivot, cmbPivot)
)

columns = (
    pivot
    .values()
    .map(lambda u: set(u.keys()))
    .reduce(lambda s,t: s.union(t))
)

columns_ord = sorted(columns)

result = sqlCtx.createDataFrame(
    pivot
    .map(lambda (k, u): [k] + [u.get(c, None) for c in columns_ord]),
        schema=StructType(
            [StructField('ID', StringType())] + 
            [StructField(c, IntegerType()) for c in columns_ord]
        )
    )

print result.show()

Наконец, вывод должен быть

+---+---+---+---+---+----+
| ID|Age| CA| UK| US|  XX|
+---+---+---+---+---+----+
|X02| 72|  7|  6|  4|   8|
|X01| 41|  2|  1|  3|null|
+---+---+---+---+---+----+

Ответ 6

Существует JIRA in Hive для PIVOT, чтобы сделать это изначально, без огромного оператора CASE для каждого значения:

https://issues.apache.org/jira/browse/HIVE-3776

Прошу проголосовать за JIRA, чтобы она была реализована раньше. Как только он попадает в Hive SQL, Spark обычно не слишком сильно отстает, и в конечном итоге он будет реализован и в Spark.