Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,7 @@ private static String buildLogoutResponseUrl(SamlRealm realm, SamlLogoutRequestH
}

private void findAndInvalidateTokens(SamlRealm realm, SamlLogoutRequestHandler.Result result, ActionListener<Integer> listener) {
final Map<String, Object> tokenMetadata = realm.createTokenMetadata(result.getNameId(), result.getSession());
final Map<String, Object> tokenMetadata = realm.createTokenMetadata(result.getNameId(), result.getSession(), null);
if (Strings.isNullOrEmpty((String) tokenMetadata.get(SamlRealm.TOKEN_METADATA_NAMEID_VALUE))) {
// If we don't have a valid name-id to match against, don't do anything
LOGGER.debug("Logout request [{}] has no NameID value, so cannot invalidate any sessions", result);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,12 +27,20 @@ public class SamlAttributes implements Releasable {

private final SamlNameId name;
private final String session;
private final String inResponseTo;
private final List<SamlAttribute> attributes;
private final List<SamlPrivateAttribute> privateAttributes;

SamlAttributes(SamlNameId name, String session, List<SamlAttribute> attributes, List<SamlPrivateAttribute> privateAttributes) {
SamlAttributes(
SamlNameId name,
String session,
String inResponseTo,
List<SamlAttribute> attributes,
List<SamlPrivateAttribute> privateAttributes
) {
this.name = name;
this.session = session;
this.inResponseTo = inResponseTo;
this.attributes = attributes;
this.privateAttributes = privateAttributes;
}
Expand Down Expand Up @@ -89,9 +97,24 @@ String session() {
return session;
}

String inResponseTo() {
return inResponseTo;
}

@Override
public String toString() {
return getClass().getSimpleName() + "(" + name + ")[" + session + "]{" + attributes + "}{" + privateAttributes + "}";
return getClass().getSimpleName()
+ "("
+ name
+ ")["
+ session
+ "]["
+ inResponseTo
+ "]{"
+ attributes
+ "}{"
+ privateAttributes
+ "}";
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@ private SamlAttributes authenticateResponse(Element element, Collection<String>
final Assertion assertion = details.v1();
final SamlNameId nameId = SamlNameId.forSubject(assertion.getSubject());
final String session = getSessionIndex(assertion);
final SamlAttributes samlAttributes = buildSamlAttributes(nameId, session, details.v2());
final SamlAttributes samlAttributes = buildSamlAttributes(nameId, session, response.getInResponseTo(), details.v2());
if (logger.isTraceEnabled()) {
StringBuilder sb = new StringBuilder();
sb.append("The SAML Assertion contained the following attributes: \n");
Expand All @@ -141,7 +141,7 @@ private SamlAttributes authenticateResponse(Element element, Collection<String>
return samlAttributes;
}

private SamlAttributes buildSamlAttributes(SamlNameId nameId, String session, List<Attribute> attributes) {
private SamlAttributes buildSamlAttributes(SamlNameId nameId, String session, String inResponseTo, List<Attribute> attributes) {
List<SamlAttributes.SamlAttribute> samlAttributes = new ArrayList<>();
List<SamlAttributes.SamlPrivateAttribute> samlPrivateAttributes = new ArrayList<>();
for (Attribute attribute : attributes) {
Expand All @@ -151,7 +151,7 @@ private SamlAttributes buildSamlAttributes(SamlNameId nameId, String session, Li
samlAttributes.add(new SamlAttributes.SamlAttribute(attribute));
}
}
return new SamlAttributes(nameId, session, samlAttributes, samlPrivateAttributes);
return new SamlAttributes(nameId, session, inResponseTo, samlAttributes, samlPrivateAttributes);
}

private static String getSessionIndex(Assertion assertion) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,7 @@ public final class SamlRealm extends Realm implements Releasable {
public static final String TOKEN_METADATA_NAMEID_SP_QUALIFIER = "saml_nameid_sp_qual";
public static final String TOKEN_METADATA_NAMEID_SP_PROVIDED_ID = "saml_nameid_sp_id";
public static final String TOKEN_METADATA_SESSION = "saml_session";
public static final String TOKEN_METADATA_IN_RESPONSE_TO = "saml_in_response_to";
public static final String TOKEN_METADATA_REALM = "saml_realm";

public static final String PRIVATE_ATTRIBUTES_METADATA = "saml_private_attributes";
Expand Down Expand Up @@ -588,7 +589,7 @@ private void buildUser(SamlAttributes attributes, ActionListener<AuthenticationR
return;
}

final Map<String, Object> tokenMetadata = createTokenMetadata(attributes.name(), attributes.session());
final Map<String, Object> tokenMetadata = createTokenMetadata(attributes.name(), attributes.session(), attributes.inResponseTo());
final Map<String, List<SecureString>> privateAttributesMetadata = attributes.privateAttributes()
.stream()
.collect(Collectors.toMap(SamlPrivateAttribute::name, SamlPrivateAttribute::values));
Expand Down Expand Up @@ -639,7 +640,7 @@ private void buildUser(SamlAttributes attributes, ActionListener<AuthenticationR
}));
}

public Map<String, Object> createTokenMetadata(SamlNameId nameId, String session) {
public Map<String, Object> createTokenMetadata(SamlNameId nameId, String session, String inResponseTo) {
final Map<String, Object> tokenMeta = new HashMap<>();
if (nameId != null) {
tokenMeta.put(TOKEN_METADATA_NAMEID_VALUE, nameId.value);
Expand All @@ -655,6 +656,9 @@ public Map<String, Object> createTokenMetadata(SamlNameId nameId, String session
tokenMeta.put(TOKEN_METADATA_NAMEID_SP_PROVIDED_ID, null);
}
tokenMeta.put(TOKEN_METADATA_SESSION, session);
if (inResponseTo != null) {
tokenMeta.put(TOKEN_METADATA_IN_RESPONSE_TO, inResponseTo);
}
tokenMeta.put(TOKEN_METADATA_REALM, name());
return tokenMeta;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -463,7 +463,7 @@ private TokenService.CreateTokenResult storeToken(byte[] userTokenBytes, byte[]
.user(new User("bob"))
.realmRef(new RealmRef("native", NativeRealmSettings.TYPE, "node01"))
.build(false);
final Map<String, Object> metadata = samlRealm.createTokenMetadata(nameId, session);
final Map<String, Object> metadata = samlRealm.createTokenMetadata(nameId, session, null);
final PlainActionFuture<TokenService.CreateTokenResult> future = new PlainActionFuture<>();
tokenService.createOAuth2Tokens(userTokenBytes, refreshTokenBytes, authentication, authentication, metadata, future);
return future.actionGet();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -278,7 +278,8 @@ public void testLogoutInvalidatesToken() throws Exception {
final Authentication.RealmRef realmRef = new Authentication.RealmRef(samlRealm.name(), SingleSpSamlRealmSettings.TYPE, "node01");
final Map<String, Object> tokenMetadata = samlRealm.createTokenMetadata(
new SamlNameId(NameID.TRANSIENT, nameId, null, null, null),
session
session,
null
);
final Authentication authentication = Authentication.newRealmAuthentication(user, realmRef);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,11 @@ public void testToString() {
final String nameFormat = randomFrom(NameID.TRANSIENT, NameID.PERSISTENT, NameID.EMAIL);
final String nameId = randomIdentifier();
final String session = randomAlphaOfLength(16);
final String inResponseTo = randomAlphanumericOfLength(16);
final SamlAttributes attributes = new SamlAttributes(
new SamlNameId(nameFormat, nameId, null, null, null),
session,
inResponseTo,
List.of(
new SamlAttributes.SamlAttribute("urn:oid:0.9.2342.19200300.100.1.1", null, List.of("peter.ng")),
new SamlAttributes.SamlAttribute("urn:oid:2.5.4.3", "name", List.of("Peter Ng")),
Expand All @@ -46,6 +48,8 @@ public void testToString() {
+ ("NameId(" + nameFormat + ")=" + nameId)
+ ")["
+ session
+ "]["
+ inResponseTo
+ "]{["
+ "urn:oid:0.9.2342.19200300.100.1.1=[peter.ng](len=1)"
+ ", "
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,9 @@
import static org.elasticsearch.core.Strings.format;
import static org.elasticsearch.test.ActionListenerUtils.anyActionListener;
import static org.elasticsearch.test.TestMatchers.throwableWithMessage;
import static org.elasticsearch.xpack.security.authc.saml.SamlRealm.CONTEXT_TOKEN_DATA;
import static org.elasticsearch.xpack.security.authc.saml.SamlRealm.PRIVATE_ATTRIBUTES_METADATA;
import static org.elasticsearch.xpack.security.authc.saml.SamlRealm.TOKEN_METADATA_IN_RESPONSE_TO;
import static org.hamcrest.Matchers.arrayContainingInAnyOrder;
import static org.hamcrest.Matchers.contains;
import static org.hamcrest.Matchers.containsInAnyOrder;
Expand All @@ -101,6 +103,7 @@
import static org.hamcrest.Matchers.not;
import static org.hamcrest.Matchers.notNullValue;
import static org.hamcrest.Matchers.nullValue;
import static org.hamcrest.Matchers.typeCompatibleWith;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;
Expand Down Expand Up @@ -580,6 +583,7 @@ private AuthenticationResult<User> performAuthentication(
final String nameIdValue = principalIsEmailAddress ? "clint.barton@shield.gov" : "clint.barton";
final String uidValue = principalIsEmailAddress ? "cbarton@shield.gov" : "cbarton";
final String realmType = SingleSpSamlRealmSettings.TYPE;
final String inResponseTo = "_request_id_12345";

final RealmConfig.RealmIdentifier realmIdentifier = new RealmConfig.RealmIdentifier("mock", "mock_lookup");
final MockLookupRealm lookupRealm = new MockLookupRealm(
Expand Down Expand Up @@ -669,6 +673,7 @@ private AuthenticationResult<User> performAuthentication(
final SamlAttributes attributes = new SamlAttributes(
new SamlNameId(NameIDType.PERSISTENT, nameIdValue, idp.getEntityID(), sp.getEntityId(), null),
randomAlphaOfLength(16),
inResponseTo,
List.of(
new SamlAttributes.SamlAttribute("urn:oid:0.9.2342.19200300.100.1.1", "uid", Collections.singletonList(uidValue)),
new SamlAttributes.SamlAttribute("urn:oid:1.3.6.1.4.1.5923.1.5.1.1", "groups", groups),
Expand Down Expand Up @@ -698,6 +703,11 @@ public void onResponse(AuthenticationResult<User> result) {
@SuppressWarnings("unchecked")
var metadata = (Map<String, List<SecureString>>) result.getMetadata().get(PRIVATE_ATTRIBUTES_METADATA);
secureAttributes.forEach((name, value) -> assertThat(metadata.get(name), equalTo(value)));
Object tokenContext = result.getMetadata().get(CONTEXT_TOKEN_DATA);
assertThat(tokenContext, notNullValue());
assertThat(tokenContext.getClass(), typeCompatibleWith(Map.class));
Object returnedInResponseTo = ((Map<?, ?>) tokenContext).get(TOKEN_METADATA_IN_RESPONSE_TO);
assertThat(returnedInResponseTo, equalTo(inResponseTo));
}
super.onResponse(result);
}
Expand Down Expand Up @@ -778,6 +788,7 @@ private List<String> performAttributeSelectionWithSplit(String delimiter, String
final SamlAttributes attributes = new SamlAttributes(
new SamlNameId(NameIDType.TRANSIENT, randomAlphaOfLength(24), null, null, null),
randomAlphaOfLength(16),
randomAlphaOfLength(16),
List.of(
new SamlAttributes.SamlAttribute(
"departments",
Expand Down Expand Up @@ -850,6 +861,7 @@ public void testAttributeSelectionWithSplitAndListThrowsSecurityException() {
final SamlAttributes attributes = new SamlAttributes(
new SamlNameId(NameIDType.TRANSIENT, randomAlphaOfLength(24), null, null, null),
randomAlphaOfLength(16),
randomAlphaOfLength(16),
List.of(
new SamlAttributes.SamlAttribute(
"departments",
Expand Down Expand Up @@ -883,6 +895,7 @@ public void testAttributeSelectionWithRegex() {
final SamlAttributes attributes = new SamlAttributes(
new SamlNameId(NameIDType.TRANSIENT, randomAlphaOfLength(24), null, null, null),
randomAlphaOfLength(16),
randomAlphaOfLength(16),
List.of(
new SamlAttributes.SamlAttribute(
"urn:oid:0.9.2342.19200300.100.1.3",
Expand Down Expand Up @@ -1053,6 +1066,7 @@ public void testNonMatchingPrincipalPatternThrowsSamlException() throws Exceptio
final SamlAttributes attributes = new SamlAttributes(
new SamlNameId(NameIDType.TRANSIENT, randomAlphaOfLength(12), null, null, null),
randomAlphaOfLength(16),
randomAlphaOfLength(16),
List.of(new SamlAttributes.SamlAttribute("urn:oid:0.9.2342.19200300.100.1.3", "mail", Collections.singletonList(mail))),
List.of()
);
Expand Down