66import json
77import urllib .parse
88from io import BytesIO
9+ from json import dumps
910from typing import Any , Optional , Union
1011from typing_extensions import Literal
1112
1213from graphql import ExecutionResult
13- from webob import Request , Response
14+ from urllib3 import encode_multipart_formdata
1415
1516from graphql_server .http import GraphQLHTTPResponse
1617from graphql_server .http .ides import GraphQL_IDE
1718from graphql_server .webob import GraphQLView as BaseGraphQLView
1819from tests .http .context import get_context
1920from tests .views .schema import Query , schema
21+ from webob import Request , Response
2022
21- from .base import JSON , HttpClient , Response as ClientResponse , ResultOverrideFunction
23+ from .base import JSON , HttpClient , ResultOverrideFunction
24+ from .base import Response as ClientResponse
2225
2326
2427class GraphQLView (BaseGraphQLView [dict [str , object ], object ]):
@@ -82,18 +85,16 @@ async def _graphql_request(
8285
8386 url = "/graphql"
8487
85- if body and files :
86- body .update ({name : (file , name ) for name , file in files .items ()})
88+ headers = self ._get_headers (method = method , headers = headers , files = files )
8789
8890 if method == "get" :
8991 body_encoded = urllib .parse .urlencode (body or {})
9092 url = f"{ url } ?{ body_encoded } "
91- else :
92- if body :
93- data = body if files else json .dumps (body )
94- kwargs ["body" ] = data
95-
96- headers = self ._get_headers (method = method , headers = headers , files = files )
93+ elif body :
94+ if files :
95+ header_pairs , body = create_multipart_request_body (body , files )
96+ headers = dict (header_pairs )
97+ kwargs ["body" ] = body
9798
9899 return await self .request (url , method , headers = headers , ** kwargs )
99100
@@ -104,9 +105,11 @@ def _do_request(
104105 headers : Optional [dict [str , str ]] = None ,
105106 ** kwargs : Any ,
106107 ) -> ClientResponse :
107- body = kwargs .get ("body" , None )
108+ body = kwargs .pop ("body" , None )
109+ if isinstance (body , dict ):
110+ body = json .dumps (body ).encode ("utf-8" )
108111 req = Request .blank (
109- url , method = method .upper (), headers = headers or {}, body = body
112+ url , method = method .upper (), headers = headers or {}, body = body , ** kwargs
110113 )
111114 resp = self .view .dispatch_request (req )
112115 return ClientResponse (
@@ -139,5 +142,26 @@ async def post(
139142 json : Optional [JSON ] = None ,
140143 headers : Optional [dict [str , str ]] = None ,
141144 ) -> ClientResponse :
142- body = json if json is not None else data
145+ body = dumps ( json ). encode ( "utf-8" ) if json is not None else data
143146 return await self .request (url , "post" , headers = headers , body = body )
147+
148+
149+ def create_multipart_request_body (
150+ body : dict [str , object ], files : dict [str , BytesIO ]
151+ ) -> tuple [list [tuple [str , str ]], bytes ]:
152+ fields = {
153+ "operations" : body ["operations" ],
154+ "map" : body ["map" ],
155+ }
156+
157+ for filename , data in files .items ():
158+ fields [filename ] = (filename , data .read ().decode (), "text/plain" )
159+
160+ request_body , content_type_header = encode_multipart_formdata (fields )
161+
162+ headers = [
163+ ("Content-Type" , content_type_header ),
164+ ("Content-Length" , f"{ len (request_body )} " ),
165+ ]
166+
167+ return headers , request_body
0 commit comments